patches.py 5.05 KB
Newer Older
1
2
3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
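"""vLLM-specific patches for GPU Memory Service integration.

This module contains vLLM-specific patches that are applied when the GMSWorker
module is imported:
- MemorySnapshot.measure patch (adjusts free memory for read mode)

It also defines the shadow-mode patches installed by
apply_shadow_mode_patches():
- request_memory patch (bypasses the free >= requested check)
- NixlConnector.register_kv_caches patch (skips registration when kv_caches
  is empty)

Note: The torch.cuda.empty_cache patch is in integrations/common/patches.py
"""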

from __future__ import annotations

import logging

from gpu_memory_service import get_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType
19
from gpu_memory_service.integrations.vllm.utils import is_shadow_mode
20
21
22
23

# Module-level logger used by every patch in this file.
logger = logging.getLogger(__name__)

# One-shot guards: each flag is set to True by the matching patch_* function
# once its monkey-patch has been installed, making repeated calls no-ops.
_memory_snapshot_patched: bool = False
_request_memory_patched: bool = False
_register_kv_caches_patched: bool = False


# =============================================================================
# Core GMS patch (always applied)
# =============================================================================
31
32
33


def patch_memory_snapshot() -> None:
    """Patch ``MemorySnapshot.measure`` to add committed GMS bytes to ``free_memory``.

    The patch is installed at most once per process (guarded by the
    module-level ``_memory_snapshot_patched`` flag). If
    ``vllm.utils.mem_utils.MemorySnapshot`` cannot be imported, a debug
    message is logged and nothing is patched.

    After the original ``measure`` has run, the wrapper looks up the
    ``"weights"`` GMS client memory manager:

    * RO lock: the ``aligned_size`` of every handle returned by
      ``manager.list_handles()`` is summed and added to ``self.free_memory``.
    * any other lock type: ``free_memory`` is left unchanged.

    Raises (from the installed wrapper, not from this function):
        AssertionError: if the ``"weights"`` GMS client memory manager is
            not initialized.
    """
    global _memory_snapshot_patched

    if _memory_snapshot_patched:
        return

    try:
        from vllm.utils.mem_utils import MemorySnapshot
    except ImportError:
        logger.debug("[GMS Patch] MemorySnapshot not available")
        return

    # Local import: only needed once the patch is actually being installed.
    import functools

    original_measure = MemorySnapshot.measure

    @functools.wraps(original_measure)
    def patched_measure(self):
        # Let vLLM take its normal measurement first; keep whatever it returns
        # so the wrapper stays transparent to callers.
        result = original_measure(self)

        manager = get_gms_client_memory_manager("weights")
        if manager is None:
            # Explicit raise instead of `assert` so the check is not stripped
            # under `python -O`; the exception type is kept as AssertionError.
            raise AssertionError("GMS client is not initialized")

        if manager.granted_lock_type == GrantedLockType.RO:
            # NOTE(review): presumably these bytes are weights already
            # committed through GMS, so they should not count against this
            # process's budget - confirm against the GMS design.
            allocations = manager.list_handles()
            committed_bytes = sum(alloc.aligned_size for alloc in allocations)
        else:
            # NOTE: by design, we want to assume we have the whole GPU when writing
            # weights for the first time, so we don't make an adjustment.
            committed_bytes = 0
            logger.info("[GMS] RW mode - skipping committed memory adjustment")

        original_free = self.free_memory
        self.free_memory += committed_bytes

        if committed_bytes > 0:
            logger.info(
                "[GMS Patch] Adjusted free_memory: %.2f GiB + %.2f GiB = %.2f GiB",
                original_free / (1 << 30),
                committed_bytes / (1 << 30),
                self.free_memory / (1 << 30),
            )

        return result

    MemorySnapshot.measure = patched_measure
    _memory_snapshot_patched = True
    logger.info("[GMS Patch] Patched MemorySnapshot.measure")
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154


# =============================================================================
# Shadow mode patches
# =============================================================================


def patch_request_memory() -> None:
    """Replace ``vllm.v1.worker.utils.request_memory`` with a check-free version.

    The replacement computes ``int(total_memory * gpu_memory_utilization)``,
    logs the requested and currently free amounts, and returns the requested
    byte count without comparing it to the free memory (a shadow engine
    shares the GPU with the active engine).

    Installed at most once per process. If ``vllm.v1.worker.utils`` cannot be
    imported, a debug message is logged and nothing is patched.

    NOTE(review): only the attribute on the ``utils`` module is rebound;
    modules that imported ``request_memory`` by name earlier would presumably
    keep the original - confirm the import order at the call site.
    """
    global _request_memory_patched

    if _request_memory_patched:
        # Already installed; nothing left to do.
        return

    def patched_request_memory(init_snapshot, cache_config):
        bytes_per_gib = 1 << 30
        total_bytes = init_snapshot.total_memory
        budget = int(total_bytes * cache_config.gpu_memory_utilization)
        logger.info(
            "[GMS Patch] Shadow mode: bypassing memory check "
            "(requested=%.2f GiB, free=%.2f GiB)",
            budget / bytes_per_gib,
            init_snapshot.free_memory / bytes_per_gib,
        )
        return budget

    try:
        from vllm.v1.worker import utils as worker_utils
    except ImportError:
        logger.debug("[GMS Patch] vllm.v1.worker.utils not available")
        return

    setattr(worker_utils, "request_memory", patched_request_memory)
    _request_memory_patched = True
    logger.info("[GMS Patch] Patched request_memory for shadow mode")


def patch_register_kv_caches() -> None:
    """Wrap ``NixlConnector.register_kv_caches`` so empty input is a no-op.

    The wrapper logs and returns ``None`` when ``kv_caches`` is empty (or
    otherwise falsy); in every other case it delegates to the original
    method and returns its result.

    Installed at most once per process. If ``NixlConnector`` cannot be
    imported, a debug message is logged and nothing is patched.
    """
    global _register_kv_caches_patched

    if _register_kv_caches_patched:
        # Already installed; nothing left to do.
        return

    try:
        from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
            NixlConnector,
        )
    except ImportError:
        logger.debug("[GMS Patch] NixlConnector not available")
        return

    # Keep a handle on the unpatched method so the wrapper can delegate to it.
    wrapped = NixlConnector.register_kv_caches

    def patched_register_kv_caches(self, kv_caches):
        if kv_caches:
            return wrapped(self, kv_caches)
        logger.info("[GMS Patch] Skipping KV cache registration (empty kv_caches)")
        return None

    setattr(NixlConnector, "register_kv_caches", patched_register_kv_caches)
    _register_kv_caches_patched = True
    logger.info("[GMS Patch] Patched NixlConnector.register_kv_caches")


# =============================================================================
# Patch application helper
# =============================================================================


def apply_shadow_mode_patches() -> None:
    """Install every shadow-mode monkey-patch when running in shadow mode.

    Does nothing unless ``is_shadow_mode()`` is true. Otherwise it installs
    the ``request_memory`` patch, then the ``register_kv_caches`` patch, and
    logs a confirmation.
    """
    if is_shadow_mode():
        # Order matches the original: request_memory first, then KV caches.
        for install_patch in (patch_request_memory, patch_register_kv_caches):
            install_patch()
        logger.info("[GMS Patch] Shadow mode patches applied")