# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""vLLM-specific patches for GPU Memory Service integration.

This module contains vLLM-specific patches that are applied when the GMSWorker
module is imported:
- MemorySnapshot.measure patch (adjusts free memory for read mode)

It also defines the shadow mode patches, which are installed only through
apply_shadow_mode_patches():
- request_memory patch (bypasses the free >= requested memory check)
- NixlConnector.register_kv_caches patch (skips empty kv_caches)

Note: The torch.cuda.empty_cache patch is in integrations/common/patches.py
"""

from __future__ import annotations

import logging

from gpu_memory_service import get_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.integrations.vllm.utils import is_shadow_mode

logger = logging.getLogger(__name__)

# Guards that make each patch function idempotent: a flag flips to True only
# after the corresponding monkey-patch has actually been installed.
_memory_snapshot_patched = False
_request_memory_patched = False
_register_kv_caches_patched = False


# =============================================================================
# Core GMS patch (always applied)
# =============================================================================


def patch_memory_snapshot() -> None:
    """Add committed GMS bytes to MemorySnapshot.free_memory.

    Replaces ``MemorySnapshot.measure`` with a wrapper that first runs the
    original ``measure`` and then, when the GMS client holds an RO lock, adds
    the sum of the ``aligned_size`` of every handle reported by the client to
    ``free_memory``. With any other lock type the value is left unchanged.

    The patch is installed at most once per process. If
    ``vllm.utils.mem_utils`` cannot be imported, the function logs at debug
    level and returns without patching.
    """
    global _memory_snapshot_patched

    if _memory_snapshot_patched:
        return

    try:
        from vllm.utils.mem_utils import MemorySnapshot
    except ImportError:
        logger.debug("[GMS Patch] MemorySnapshot not available")
        return

    original_measure = MemorySnapshot.measure

    def patched_measure(self):
        """Run the original measure, then credit GMS-committed bytes.

        Raises:
            AssertionError: If the GMS client memory manager is not
                initialized.
        """
        original_measure(self)

        manager = get_gms_client_memory_manager()
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if manager is None:
            raise AssertionError("GMS client is not initialized")

        if manager.granted_lock_type == GrantedLockType.RO:
            allocations = manager.list_handles()
            committed_bytes = sum(
                alloc.get("aligned_size", 0) for alloc in allocations
            )
        else:
            # NOTE: by design, we want to assume we have the whole GPU when writing
            # weights for the first time, so we don't make an adjustment.
            committed_bytes = 0
            logger.info("[GMS] RW mode - skipping committed memory adjustment")

        original_free = self.free_memory
        self.free_memory += committed_bytes

        if committed_bytes > 0:
            logger.info(
                "[GMS Patch] Adjusted free_memory: %.2f GiB + %.2f GiB = %.2f GiB",
                original_free / (1 << 30),
                committed_bytes / (1 << 30),
                self.free_memory / (1 << 30),
            )

    MemorySnapshot.measure = patched_measure
    _memory_snapshot_patched = True
    logger.info("[GMS Patch] Patched MemorySnapshot.measure")


# =============================================================================
# Shadow mode patches
# =============================================================================


def patch_request_memory() -> None:
    """Bypass free >= requested check (shadow shares GPU with active engine).

    Replaces ``vllm.v1.worker.utils.request_memory`` with a function that
    returns ``int(total_memory * gpu_memory_utilization)`` without comparing
    it against the snapshot's free memory.

    The patch is installed at most once per process. If
    ``vllm.v1.worker.utils`` cannot be imported, the function logs at debug
    level and returns without patching.
    """
    global _request_memory_patched

    if _request_memory_patched:
        return

    try:
        from vllm.v1.worker import utils as worker_utils
    except ImportError:
        logger.debug("[GMS Patch] vllm.v1.worker.utils not available")
        return

    def patched_request_memory(init_snapshot, cache_config):
        """Return the requested byte budget; free memory is only logged."""
        requested_memory = int(
            init_snapshot.total_memory * cache_config.gpu_memory_utilization
        )
        logger.info(
            "[GMS Patch] Shadow mode: bypassing memory check "
            "(requested=%.2f GiB, free=%.2f GiB)",
            requested_memory / (1 << 30),
            init_snapshot.free_memory / (1 << 30),
        )
        return requested_memory

    # Only the attribute on the utils module is rebound here.
    worker_utils.request_memory = patched_request_memory
    _request_memory_patched = True
    logger.info("[GMS Patch] Patched request_memory for shadow mode")


def patch_register_kv_caches() -> None:
    """Skip NixlConnector.register_kv_caches when kv_caches is empty.

    Wraps ``NixlConnector.register_kv_caches`` so that a falsy ``kv_caches``
    argument logs a message and returns ``None``; any other value is
    forwarded to the original method and its result is returned.

    The patch is installed at most once per process. If the NIXL connector
    module cannot be imported, the function logs at debug level and returns
    without patching.
    """
    global _register_kv_caches_patched

    if _register_kv_caches_patched:
        return

    try:
        from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
            NixlConnector,
        )
    except ImportError:
        logger.debug("[GMS Patch] NixlConnector not available")
        return

    original_register = NixlConnector.register_kv_caches

    def patched_register_kv_caches(self, kv_caches):
        """Delegate to the original method unless kv_caches is falsy."""
        if not kv_caches:
            logger.info("[GMS Patch] Skipping KV cache registration (empty kv_caches)")
            return
        return original_register(self, kv_caches)

    NixlConnector.register_kv_caches = patched_register_kv_caches
    _register_kv_caches_patched = True
    logger.info("[GMS Patch] Patched NixlConnector.register_kv_caches")


# =============================================================================
# Patch application helper
# =============================================================================


def apply_shadow_mode_patches() -> None:
    """Apply shadow mode monkey-patches. No-ops if not in shadow mode.

    When ``is_shadow_mode()`` is true, calls ``patch_request_memory()`` and
    ``patch_register_kv_caches()`` in that order and then logs completion.
    The completion message is logged even if a patch function returned early
    because its vLLM import was unavailable.
    """
    if not is_shadow_mode():
        return

    patch_request_memory()
    patch_register_kv_caches()
    logger.info("[GMS Patch] Shadow mode patches applied")