patches.py 10.1 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""SGLang-specific patches for GPU Memory Service integration.

- patch_torch_memory_saver: Routes to GMS hybrid implementation
- patch_model_runner: Fixes memory accounting with pre-loaded weights
8
- patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them)
9
10
11
12
"""

from __future__ import annotations

13
import inspect
14
import logging
15
from contextlib import contextmanager
16
17
18
19
20
21
22
23
from typing import Optional

import torch

logger = logging.getLogger(__name__)

# Idempotency guards: each patch_* function in this module sets its flag once
# the corresponding monkeypatch has been installed and returns early on later
# calls when the flag is already set.
_torch_memory_saver_patched = False
_model_runner_patched = False
_static_state_patched = False


def patch_torch_memory_saver() -> None:
    """Patch torch_memory_saver to use GPU Memory Service implementation.

    The following is installed on the ``torch_memory_saver`` package:

    * ``TorchMemorySaver._ensure_initialized`` is replaced. When the
      singleton's ``hook_mode`` is ``None`` or ``"gms"`` it builds a
      ``GMSMemorySaverImpl`` wrapping a ``hook_mode="torch"``
      ``_TorchMemorySaverImpl``; any other hook mode is delegated to the
      original method.
    * ``torch_memory_saver.configure_subprocess`` is replaced with a context
      manager that does nothing in GMS mode and delegates to the original
      otherwise.
    * A read-only ``gms_impl`` property is added to ``TorchMemorySaver``.
    * If the module-level singleton already holds an implementation, it is
      reset so the next ``_ensure_initialized`` call goes through the patch.

    This function is idempotent - calling it multiple times has no effect.
    This patch is only applied when GMSModelLoader is imported (load_format="gms").
    If ``torch_memory_saver`` cannot be imported, nothing is patched.
    """
    global _torch_memory_saver_patched
    if _torch_memory_saver_patched:
        return

    try:
        import torch_memory_saver
        import torch_memory_saver.entrypoint as entrypoint_module
    except ImportError:
        logger.debug("[GMS] torch_memory_saver not installed, skipping patch")
        return

    # Resolve everything that can fail BEFORE mutating torch_memory_saver.
    # If one of these lookups raised after the monkeypatches were installed,
    # the guard flag would stay False and a retry would capture the already
    # patched callables as the "originals".
    from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl

    saver_cls = entrypoint_module.TorchMemorySaver
    original_ensure_initialized = saver_cls._ensure_initialized
    original_configure_subprocess = torch_memory_saver.configure_subprocess

    def _selects_gms(hook_mode) -> bool:
        """Return True when ``hook_mode`` should be served by GMS."""
        return hook_mode is None or hook_mode == "gms"

    def patched_ensure_initialized(self):
        """Patched _ensure_initialized that uses GPU Memory Service implementation."""
        if self._impl is not None:
            logger.debug("[GMS] TorchMemorySaver already initialized, skipping")
            return

        hook_mode = self._impl_ctor_kwargs.get("hook_mode")
        logger.info("[GMS] TorchMemorySaver initializing with hook_mode=%s", hook_mode)

        if not _selects_gms(hook_mode):
            # Explicit non-GMS hook mode: keep upstream behavior.
            logger.info("[GMS] Using default torch_memory_saver hook mode")
            original_ensure_initialized(self)
            return

        from torch_memory_saver.entrypoint import _TorchMemorySaverImpl

        # Imported at call time so the lock mode set by setup_gms()
        # (defaults to RW_OR_RO) is read when the impl is created.
        from gpu_memory_service.integrations.sglang import _gms_lock_mode

        # Device comes from torch.cuda.current_device() (already set by SGLang).
        device_index = torch.cuda.current_device()

        # Underlying torch impl handles the non-GMS tags.
        torch_impl = _TorchMemorySaverImpl(hook_mode="torch")

        gms_impl = GMSMemorySaverImpl(
            torch_impl=torch_impl,
            device_index=device_index,
            mode=_gms_lock_mode,
        )

        # Set _impl directly (accessible via the gms_impl property).
        self._impl = gms_impl
        logger.info(
            "[GMS] Using GMS mode (device=%d, mode=%s)",
            device_index,
            gms_impl.get_mode(),
        )
        del self._impl_ctor_kwargs

    @contextmanager
    def patched_configure_subprocess():
        """Avoid LD_PRELOAD in GMS mode; keep upstream behavior otherwise."""
        singleton = torch_memory_saver.torch_memory_saver
        # NOTE(review): once _impl_ctor_kwargs has been deleted the hook mode
        # reads as None here, so this path is treated as GMS mode - confirm
        # that is intended for a singleton initialized in a non-GMS mode.
        ctor_kwargs = getattr(singleton, "_impl_ctor_kwargs", None) or {}

        if _selects_gms(ctor_kwargs.get("hook_mode")):
            logger.info("[GMS] torch_memory_saver.configure_subprocess is a no-op")
            yield
            return

        with original_configure_subprocess():
            yield

    @property
    def gms_impl(self) -> Optional[GMSMemorySaverImpl]:
        """Get the GMS impl if installed, None otherwise."""
        if isinstance(self._impl, GMSMemorySaverImpl):
            return self._impl
        return None

    # All lookups succeeded; install the patches.
    saver_cls._ensure_initialized = patched_ensure_initialized
    torch_memory_saver.configure_subprocess = patched_configure_subprocess
    saver_cls.gms_impl = gms_impl

    # If the singleton was already initialized before this patch ran (e.g.,
    # due to import ordering in multiprocessing spawn), reset _impl so the
    # next call to _ensure_initialized goes through the patched version and
    # creates GMSMemorySaverImpl instead of the default _TorchMemorySaverImpl.
    singleton = torch_memory_saver.torch_memory_saver
    if singleton._impl is not None:
        logger.debug(
            "[GMS] TorchMemorySaver singleton already initialized, "
            "resetting to force GMS re-init on next use"
        )
        singleton._impl = None
        # The original _ensure_initialized deletes _impl_ctor_kwargs after
        # creating _impl.  Restore it so the patched version can read it.
        if not hasattr(singleton, "_impl_ctor_kwargs"):
            singleton._impl_ctor_kwargs = {}

    _torch_memory_saver_patched = True
    logger.debug("[GMS] Patched torch_memory_saver")


def patch_model_runner() -> None:
    """Patch SGLang's ModelRunner to fix memory accounting with pre-loaded weights.

    SGLang 0.5.9 passes a startup free-memory snapshot as total_gpu_memory into
    init_memory_pool(). In GMS read mode, imported weights can already occupy GPU
    memory, so that snapshot is lower than physical device capacity and the KV cache
    overhead term is under-reserved.

    The wrapper installed here replaces ``total_gpu_memory`` with the device's
    physical capacity in GiB, but only when all of the following hold:

    * a GMS memory saver impl exists and reports imported weight bytes > 0;
    * the first non-``self`` parameter of the original ``init_memory_pool``
      is named ``total_gpu_memory``;
    * the caller actually passed that argument (positionally or by keyword).

    In every other case the original method is called with its arguments
    untouched. The function is idempotent and does nothing if ``ModelRunner``
    cannot be imported.
    """
    global _model_runner_patched

    if _model_runner_patched:
        return

    try:
        from sglang.srt.model_executor.model_runner import ModelRunner
    except ImportError:
        logger.warning("[GMS] Could not import ModelRunner, skipping patch")
        return

    if hasattr(ModelRunner, "_gms_patched"):
        # The class already carries the marker (e.g. this module was loaded
        # under a second name); record it so later calls skip the import.
        _model_runner_patched = True
        return

    import functools

    original_init_memory_pool = ModelRunner.init_memory_pool
    # Name of the first parameter after ``self``; None if there is none.
    memory_arg_name = next(
        (
            name
            for name in inspect.signature(original_init_memory_pool).parameters
            if name != "self"
        ),
        None,
    )

    @functools.wraps(original_init_memory_pool)
    def patched_init_memory_pool(self, *args, **kwargs):
        # SGLang's KV cache formula uses total_gpu_memory as the baseline:
        #   rest_memory = available - total*(1-mem_fraction).
        # Replace that baseline with physical device capacity when GMS
        # imported weights are already resident. Newer SGLang versions
        # changed this API, so only the old total_gpu_memory parameter shape
        # is rewritten.
        from gpu_memory_service.integrations.sglang.memory_saver import (
            get_gms_memory_saver_impl,
        )

        impl = get_gms_memory_saver_impl()
        if impl is None or not (impl.get_imported_weights_bytes() > 0):
            return original_init_memory_pool(self, *args, **kwargs)

        if memory_arg_name != "total_gpu_memory":
            if memory_arg_name is not None:
                logger.info(
                    "[GMS] Leaving %s unchanged in patched init_memory_pool",
                    memory_arg_name,
                )
            return original_init_memory_pool(self, *args, **kwargs)

        if args:
            had_value, old_value = True, args[0]
        elif memory_arg_name in kwargs:
            had_value, old_value = True, kwargs[memory_arg_name]
        else:
            had_value, old_value = False, None

        if not had_value:
            # Nothing to rewrite: do not claim an adjustment happened.
            logger.info(
                "[GMS] total_gpu_memory not passed to init_memory_pool; "
                "leaving call unchanged"
            )
            return original_init_memory_pool(self, *args, **kwargs)

        # Query the device only once we know the value will be used.
        total_memory_gib = torch.cuda.get_device_properties(
            torch.cuda.current_device()
        ).total_memory / (1 << 30)

        if args:
            args = (total_memory_gib, *args[1:])
        else:
            # Copy so the caller's mapping is never mutated.
            kwargs = {**kwargs, memory_arg_name: total_memory_gib}

        logger.info(
            "[GMS] Adjusted total_gpu_memory: %s -> %.2f GiB",
            (
                f"{old_value:.2f} GiB"
                if isinstance(old_value, (int, float))
                else repr(old_value)
            ),
            total_memory_gib,
        )
        return original_init_memory_pool(self, *args, **kwargs)

    ModelRunner.init_memory_pool = patched_init_memory_pool
    ModelRunner._gms_patched = True
    _model_runner_patched = True
    logger.info("[GMS] Patched ModelRunner.init_memory_pool")


def patch_static_state_for_gms() -> None:
    """Swap SGLang's _export/_import_static_state for no-ops when using GMS.

    SGLang's release_memory_occupation clones every named buffer via
    buffer.detach().clone() through the default CUDA allocator, then restores
    them during resume_memory_occupation. The replacements installed here
    export an empty buffer list and ignore the import.

    This patch must run inside the scheduler child process (which uses
    multiprocessing spawn).  It is triggered by the GMSModelLoader import
    in model_loader.py, which executes at module level in the child.

    Any exception raised while importing or patching the SGLang mixin module
    is logged as a warning and swallowed; the guard flag then stays unset.
    """
    global _static_state_patched

    import os

    pid = os.getpid()
    logger.info(
        "[GMS] patch_static_state_for_gms called (pid=%d, already_patched=%s)",
        pid,
        _static_state_patched,
    )
    if _static_state_patched:
        return

    def _export_noop(model):
        """NO-OP: GMS preserves buffers via VA-stable unmap/remap."""
        return {"buffers": []}

    def _import_noop(model, static_params):
        """NO-OP: GMS preserves buffers via VA-stable unmap/remap."""
        return None

    replacements = (
        ("_export_static_state", _export_noop),
        ("_import_static_state", _import_noop),
    )

    try:
        from sglang.srt.managers import scheduler_update_weights_mixin as _mixin

        for attr_name, noop in replacements:
            setattr(_mixin, attr_name, noop)
        _static_state_patched = True
        logger.info(
            "[GMS] Patched _export/_import_static_state -> no-op (pid=%d)",
            pid,
        )
    except Exception:
        # Best-effort: log with traceback and continue unpatched.
        logger.warning(
            "[GMS] Could not patch scheduler_update_weights_mixin: ",
            exc_info=True,
        )