# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Common utilities shared across GMS integrations."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from gpu_memory_service.client.memory_manager import GMSClientMemoryManager

logger = logging.getLogger(__name__)


def get_gms_lock_mode(extra_config: dict):
    """Pick the GMS lock type requested via model_loader_extra_config.

    Args:
        extra_config: Loader extra-config mapping. Only the
            ``"gms_read_only"`` key is consulted; a missing key counts
            as false.

    Returns:
        ``RequestedLockType.RO`` when ``gms_read_only`` is truthy,
        otherwise ``RequestedLockType.RW_OR_RO`` (the default).
    """
    # Resolved at call time rather than at module import.
    from gpu_memory_service.common.types import RequestedLockType

    read_only_requested = extra_config.get("gms_read_only", False)
    if not read_only_requested:
        return RequestedLockType.RW_OR_RO

    logger.info("[GMS] gms_read_only=True, forcing RO mode")
    return RequestedLockType.RO


def setup_meta_tensor_workaround() -> None:
    """Enable workaround for meta tensor operations like torch.nonzero()."""
    try:
        import torch.fx.experimental._config as fx_config

        fx_config.meta_nonzero_assume_all_nonzero = True
    except (ImportError, AttributeError):
        pass


def finalize_gms_write(
    allocator: "GMSClientMemoryManager", model: torch.nn.Module
) -> int:
    """Finalize GMS write mode: register tensors, commit, switch to read.

    Flow: register tensors -> sync -> commit (server-only) -> disconnect -> connect(RO)

    Args:
        allocator: The GMS client memory manager in write mode.
        model: The loaded model with weights to register.

    Returns:
        Total bytes committed, i.e. ``allocator.total_bytes`` as read right
        after tensor registration and before the commit.

    Raises:
        RuntimeError: If commit fails (``allocator.commit()`` returns a
            falsy value).
    """
    from gpu_memory_service.client.torch.module import register_module_tensors
    from gpu_memory_service.common.types import RequestedLockType

    register_module_tensors(allocator, model)
    # Captured before commit/disconnect so the reported size reflects the
    # write-mode allocator state.
    total_bytes = allocator.total_bytes

    # Synchronize before commit — caller's writes must be visible
    torch.cuda.synchronize()

    if not allocator.commit():
        raise RuntimeError("GMS commit failed")

    # commit() closed the RW socket; acquire RO for inference
    allocator.disconnect()  # no-op if commit already cleared _client, but safe
    allocator.connect(RequestedLockType.RO)

    # NOTE(review): reads the allocator's private ``_mappings`` for logging
    # only — presumably a sized container of active mappings; confirm.
    logger.info(
        "[GMS] Committed %.2f GiB, switched to read mode with %d mappings",
        total_bytes / (1 << 30),
        len(allocator._mappings),
    )

    return int(total_bytes)