# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Common utilities shared across GMS integrations."""

from __future__ import annotations

import logging
from dataclasses import replace
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Used only in type annotations (e.g. finalize_gms_write); guarded by
    # TYPE_CHECKING so this module is not imported at runtime.
    from gpu_memory_service.client.memory_manager import GMSClientMemoryManager

# Module-level logger; messages emitted from this file are tagged "[GMS]".
logger = logging.getLogger(__name__)


def get_gms_lock_mode(extra_config: dict | None):
    """Resolve GMS lock mode from model_loader_extra_config.

    Args:
        extra_config: The loader's ``model_loader_extra_config`` mapping.
            ``None`` (or any falsy value) is treated as an empty mapping,
            the same way ``strip_gms_model_loader_config`` handles it.

    Returns:
        ``RequestedLockType.RO`` if ``gms_read_only`` is truthy, otherwise
        ``RequestedLockType.RW_OR_RO`` (the default).
    """
    from gpu_memory_service.common.types import RequestedLockType

    # Tolerate a missing/None config instead of raising AttributeError.
    if (extra_config or {}).get("gms_read_only", False):
        logger.info("[GMS] gms_read_only=True, forcing RO mode")
        return RequestedLockType.RO
    return RequestedLockType.RW_OR_RO


def strip_gms_model_loader_config(load_config, load_format: str):
    """Return a copy of *load_config* suitable for a non-GMS backend loader.

    The copy has ``load_format`` overridden and every ``gms_``-prefixed key
    dropped from ``model_loader_extra_config``; remaining keys keep their
    original order. A missing or falsy extra config yields an empty dict.

    *load_config* must be a dataclass instance: the copy is produced with
    ``dataclasses.replace``, so the original object is left unmodified.
    """
    source_extra = getattr(load_config, "model_loader_extra_config", {}) or {}
    backend_extra = {}
    for option, setting in source_extra.items():
        if option.startswith("gms_"):
            continue  # GMS-only option; not forwarded to the backend loader
        backend_extra[option] = setting
    return replace(
        load_config,
        model_loader_extra_config=backend_extra,
        load_format=load_format,
    )


def setup_meta_tensor_workaround() -> None:
    """Enable workaround for meta tensor operations like torch.nonzero().

    Sets ``torch.fx.experimental._config.meta_nonzero_assume_all_nonzero``
    to ``True``. This is best-effort: if the private config module cannot be
    imported, or setting the flag raises ``AttributeError``, the workaround
    is skipped and a debug message is logged instead of raising.
    """
    try:
        import torch.fx.experimental._config as fx_config

        fx_config.meta_nonzero_assume_all_nonzero = True
    except (ImportError, AttributeError) as exc:
        # Best-effort: the installed torch build may not expose this private
        # flag. Log rather than silently swallow so a later meta-tensor
        # failure can be traced back to the missing workaround.
        logger.debug("[GMS] meta tensor workaround unavailable, skipping: %s", exc)


def finalize_gms_write(
    allocator: "GMSClientMemoryManager", model: torch.nn.Module
) -> int:
    """Finalize GMS write mode: register tensors, commit, reconnect in read mode.

    Flow: register tensors -> sync -> unmap + commit -> connect(RO) -> remap

    Args:
        allocator: The GMS client memory manager in write mode.
        model: The loaded model with weights to register.

    Returns:
        Total bytes committed.
    """
    from gpu_memory_service.client.torch.module import register_module_tensors
    from gpu_memory_service.common.types import RequestedLockType

    register_module_tensors(allocator, model)
    # Snapshot the size now, before the allocator switches to read mode.
    total_bytes = allocator.total_bytes

    # Synchronize before commit — caller's writes must be visible
    torch.cuda.synchronize()

    allocator.commit()

    # Re-acquire the allocator under a read-only lock, then remap its virtual
    # address ranges (remap_all_vas) so the registered tensors remain usable.
    # NOTE(review): the "remain usable" intent is inferred from the method
    # name; confirm against GMSClientMemoryManager.
    allocator.connect(RequestedLockType.RO)
    allocator.remap_all_vas()

    # NOTE(review): reads the private `_mappings` attribute for logging only;
    # consider a public accessor on GMSClientMemoryManager.
    logger.info(
        "[GMS] Committed %.2f GiB, switched to read mode with %d mappings",
        total_bytes / (1 << 30),
        len(allocator._mappings),
    )

    return int(total_bytes)