Unverified Commit 91a8c6af authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix: simplify GMS layout state and harden GPU-backed flows (#7006)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
Co-authored-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: default avatarhhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent dd7ceb4a
...@@ -14,11 +14,13 @@ Usage: ...@@ -14,11 +14,13 @@ Usage:
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import replace
import torch import torch
from gpu_memory_service.integrations.common import patch_empty_cache from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import setup_meta_tensor_workaround from gpu_memory_service.integrations.common.utils import (
setup_meta_tensor_workaround,
strip_gms_model_loader_config,
)
from gpu_memory_service.integrations.sglang.patches import ( from gpu_memory_service.integrations.sglang.patches import (
patch_model_runner, patch_model_runner,
patch_static_state_for_gms, patch_static_state_for_gms,
...@@ -50,7 +52,10 @@ class GMSModelLoader: ...@@ -50,7 +52,10 @@ class GMSModelLoader:
if self._default_loader is None: if self._default_loader is None:
from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.loader import DefaultModelLoader
config = replace(self.load_config, load_format="auto") config = strip_gms_model_loader_config(
self.load_config,
load_format="auto",
)
self._default_loader = DefaultModelLoader(config) self._default_loader = DefaultModelLoader(config)
return self._default_loader return self._default_loader
...@@ -124,7 +129,10 @@ class GMSModelLoader: ...@@ -124,7 +129,10 @@ class GMSModelLoader:
with meta_device: with meta_device:
model = get_model( model = get_model(
model_config=model_config, model_config=model_config,
load_config=replace(self.load_config, load_format="dummy"), load_config=strip_gms_model_loader_config(
self.load_config,
load_format="dummy",
),
device_config=device_config, device_config=device_config,
) )
......
...@@ -10,11 +10,12 @@ ...@@ -10,11 +10,12 @@
from __future__ import annotations from __future__ import annotations
import inspect
import logging import logging
from contextlib import contextmanager
from typing import Optional from typing import Optional
import torch import torch
from gpu_memory_service.common.utils import get_socket_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -34,6 +35,7 @@ def patch_torch_memory_saver() -> None: ...@@ -34,6 +35,7 @@ def patch_torch_memory_saver() -> None:
return return
try: try:
import torch_memory_saver
import torch_memory_saver.entrypoint as entrypoint_module import torch_memory_saver.entrypoint as entrypoint_module
except ImportError: except ImportError:
logger.debug("[GMS] torch_memory_saver not installed, skipping patch") logger.debug("[GMS] torch_memory_saver not installed, skipping patch")
...@@ -41,6 +43,7 @@ def patch_torch_memory_saver() -> None: ...@@ -41,6 +43,7 @@ def patch_torch_memory_saver() -> None:
# Store reference to original method # Store reference to original method
original_ensure_initialized = entrypoint_module.TorchMemorySaver._ensure_initialized original_ensure_initialized = entrypoint_module.TorchMemorySaver._ensure_initialized
original_configure_subprocess = torch_memory_saver.configure_subprocess
def patched_ensure_initialized(self): def patched_ensure_initialized(self):
"""Patched _ensure_initialized that uses GPU Memory Service implementation.""" """Patched _ensure_initialized that uses GPU Memory Service implementation."""
...@@ -63,10 +66,7 @@ def patch_torch_memory_saver() -> None: ...@@ -63,10 +66,7 @@ def patch_torch_memory_saver() -> None:
# Get device from torch.cuda.current_device() (already set by SGLang) # Get device from torch.cuda.current_device() (already set by SGLang)
device_index = torch.cuda.current_device() device_index = torch.cuda.current_device()
# Resolve socket path from env or default # Create underlying torch impl for non-GMS tags.
socket_path = get_socket_path(device_index)
# Create underlying torch impl for non-weights tags (KV cache etc.)
torch_impl = _TorchMemorySaverImpl(hook_mode="torch") torch_impl = _TorchMemorySaverImpl(hook_mode="torch")
# Read lock mode set by setup_gms() (defaults to RW_OR_RO) # Read lock mode set by setup_gms() (defaults to RW_OR_RO)
...@@ -74,7 +74,6 @@ def patch_torch_memory_saver() -> None: ...@@ -74,7 +74,6 @@ def patch_torch_memory_saver() -> None:
gms_impl = GMSMemorySaverImpl( gms_impl = GMSMemorySaverImpl(
torch_impl=torch_impl, torch_impl=torch_impl,
socket_path=socket_path,
device_index=device_index, device_index=device_index,
mode=_gms_lock_mode, mode=_gms_lock_mode,
) )
...@@ -82,9 +81,8 @@ def patch_torch_memory_saver() -> None: ...@@ -82,9 +81,8 @@ def patch_torch_memory_saver() -> None:
# Set _impl directly (accessible via gms_impl property) # Set _impl directly (accessible via gms_impl property)
self._impl = gms_impl self._impl = gms_impl
logger.info( logger.info(
"[GMS] Using GMS mode (device=%d, socket=%s, mode=%s)", "[GMS] Using GMS mode (device=%d, mode=%s)",
device_index, device_index,
socket_path,
gms_impl.get_mode(), gms_impl.get_mode(),
) )
del self._impl_ctor_kwargs del self._impl_ctor_kwargs
...@@ -95,6 +93,23 @@ def patch_torch_memory_saver() -> None: ...@@ -95,6 +93,23 @@ def patch_torch_memory_saver() -> None:
entrypoint_module.TorchMemorySaver._ensure_initialized = patched_ensure_initialized entrypoint_module.TorchMemorySaver._ensure_initialized = patched_ensure_initialized
@contextmanager
def patched_configure_subprocess():
"""Avoid LD_PRELOAD in GMS mode; keep upstream behavior otherwise."""
singleton = torch_memory_saver.torch_memory_saver
ctor_kwargs = getattr(singleton, "_impl_ctor_kwargs", None) or {}
hook_mode = ctor_kwargs.get("hook_mode")
if hook_mode is None or hook_mode == "gms":
logger.info("[GMS] torch_memory_saver.configure_subprocess is a no-op")
yield
return
with original_configure_subprocess():
yield
torch_memory_saver.configure_subprocess = patched_configure_subprocess
# Add property to access GMS impl directly from the singleton # Add property to access GMS impl directly from the singleton
from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl
...@@ -132,9 +147,10 @@ def patch_torch_memory_saver() -> None: ...@@ -132,9 +147,10 @@ def patch_torch_memory_saver() -> None:
def patch_model_runner() -> None: def patch_model_runner() -> None:
"""Patch SGLang's ModelRunner to fix memory accounting with pre-loaded weights. """Patch SGLang's ModelRunner to fix memory accounting with pre-loaded weights.
When weights are pre-loaded via GMS (import-only mode), SGLang's min_per_gpu_memory SGLang 0.5.9 passes a startup free-memory snapshot as total_gpu_memory into
captured before loading is lower than device total. This causes under-reservation init_memory_pool(). In GMS read mode, imported weights can already occupy GPU
of overhead memory in KV cache calculation. memory, so that snapshot is lower than physical device capacity and the KV cache
overhead term is under-reserved.
""" """
global _model_runner_patched global _model_runner_patched
...@@ -151,25 +167,56 @@ def patch_model_runner() -> None: ...@@ -151,25 +167,56 @@ def patch_model_runner() -> None:
return return
original_init_memory_pool = ModelRunner.init_memory_pool original_init_memory_pool = ModelRunner.init_memory_pool
memory_arg_name = next(
(
name
for name in inspect.signature(original_init_memory_pool).parameters
if name != "self"
),
None,
)
def patched_init_memory_pool(self, *args, **kwargs): def patched_init_memory_pool(self, *args, **kwargs):
"""Patched init_memory_pool that uses device total for overhead calculation.""" """Patch init_memory_pool for SGLang versions that use total_gpu_memory.
SGLang's KV cache formula uses total_gpu_memory as the baseline:
rest_memory = available - total*(1-mem_fraction).
Replace that baseline with physical device capacity when GMS imported
weights are already resident. Newer SGLang versions changed this API, so
only rewrite the old total_gpu_memory parameter shape.
"""
from gpu_memory_service.integrations.sglang.memory_saver import ( from gpu_memory_service.integrations.sglang.memory_saver import (
get_gms_memory_saver_impl, get_gms_memory_saver_impl,
) )
impl = get_gms_memory_saver_impl() impl = get_gms_memory_saver_impl()
if impl is not None and impl.get_imported_weights_bytes() > 0: if impl is not None and impl.get_imported_weights_bytes() > 0:
total_memory = torch.cuda.get_device_properties( total_memory_gib = torch.cuda.get_device_properties(
torch.cuda.current_device() torch.cuda.current_device()
).total_memory ).total_memory / (1 << 30)
if hasattr(self, "min_per_gpu_memory"): if memory_arg_name == "total_gpu_memory":
old_value = self.min_per_gpu_memory if args:
self.min_per_gpu_memory = total_memory old_value = args[0]
args = (total_memory_gib,) + args[1:]
elif memory_arg_name in kwargs:
old_value = kwargs[memory_arg_name]
kwargs = dict(kwargs)
kwargs[memory_arg_name] = total_memory_gib
else:
old_value = None
logger.info(
"[GMS] Adjusted total_gpu_memory: %s -> %.2f GiB",
(
f"{old_value:.2f} GiB"
if isinstance(old_value, (int, float))
else "<missing>"
),
total_memory_gib,
)
elif memory_arg_name is not None:
logger.info( logger.info(
"[GMS] Adjusted min_per_gpu_memory: %.2f GiB -> %.2f GiB", "[GMS] Leaving %s unchanged in patched init_memory_pool",
old_value / (1 << 30), memory_arg_name,
total_memory / (1 << 30),
) )
return original_init_memory_pool(self, *args, **kwargs) return original_init_memory_pool(self, *args, **kwargs)
......
...@@ -11,11 +11,11 @@ processes import from GMS metadata (RO). ...@@ -11,11 +11,11 @@ processes import from GMS metadata (RO).
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import replace
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.client.torch.module import materialize_module_from_gms from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.types import GrantedLockType from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.utils import get_socket_path
...@@ -23,6 +23,7 @@ from gpu_memory_service.integrations.common.utils import ( ...@@ -23,6 +23,7 @@ from gpu_memory_service.integrations.common.utils import (
finalize_gms_write, finalize_gms_write,
get_gms_lock_mode, get_gms_lock_mode,
setup_meta_tensor_workaround, setup_meta_tensor_workaround,
strip_gms_model_loader_config,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -49,23 +50,14 @@ def register_gms_loader(load_format: str = "gms") -> None: ...@@ -49,23 +50,14 @@ def register_gms_loader(load_format: str = "gms") -> None:
class GMSModelLoader(BaseModelLoader): class GMSModelLoader(BaseModelLoader):
"""vLLM model loader that loads weights via GPU Memory Service.""" """vLLM model loader that loads weights via GPU Memory Service."""
# Keys in model_loader_extra_config that are GMS-specific and should
# not be passed to the fallback DefaultModelLoader.
_GMS_EXTRA_KEYS = frozenset({"gms_read_only"})
def __init__(self, load_config): def __init__(self, load_config):
super().__init__(load_config) super().__init__(load_config)
# Strip GMS-specific keys before creating the fallback loader, # Strip GMS-specific keys before creating the fallback loader,
# otherwise DefaultModelLoader rejects unknown extra config. # otherwise DefaultModelLoader rejects unknown extra config.
extra = getattr(load_config, "model_loader_extra_config", None) or {}
clean_extra = {
k: v for k, v in extra.items() if k not in self._GMS_EXTRA_KEYS
}
self.default_loader = DefaultModelLoader( self.default_loader = DefaultModelLoader(
replace( strip_gms_model_loader_config(
load_config, load_config,
load_format="auto", load_format="auto",
model_loader_extra_config=clean_extra,
) )
) )
...@@ -79,8 +71,8 @@ def register_gms_loader(load_format: str = "gms") -> None: ...@@ -79,8 +71,8 @@ def register_gms_loader(load_format: str = "gms") -> None:
device = torch.cuda.current_device() device = torch.cuda.current_device()
extra = getattr(self.load_config, "model_loader_extra_config", {}) or {} extra = getattr(self.load_config, "model_loader_extra_config", {}) or {}
mode = get_gms_lock_mode(extra) mode = get_gms_lock_mode(extra)
gms_client, pool = get_or_create_gms_client_memory_manager( gms_client = get_or_create_gms_client_memory_manager(
get_socket_path(device), get_socket_path(device, "weights"),
device, device,
mode=mode, mode=mode,
tag="weights", tag="weights",
...@@ -91,7 +83,6 @@ def register_gms_loader(load_format: str = "gms") -> None: ...@@ -91,7 +83,6 @@ def register_gms_loader(load_format: str = "gms") -> None:
else: else:
return _load_write_mode( return _load_write_mode(
gms_client, gms_client,
pool,
vllm_config, vllm_config,
model_config, model_config,
self.default_loader, self.default_loader,
...@@ -130,7 +121,6 @@ def _load_read_mode( ...@@ -130,7 +121,6 @@ def _load_read_mode(
def _load_write_mode( def _load_write_mode(
gms_client: "GMSClientMemoryManager", gms_client: "GMSClientMemoryManager",
pool,
vllm_config, vllm_config,
model_config, model_config,
default_loader, default_loader,
...@@ -143,18 +133,15 @@ def _load_write_mode( ...@@ -143,18 +133,15 @@ def _load_write_mode(
""" """
global _last_imported_weights_bytes global _last_imported_weights_bytes
from torch.cuda.memory import use_mem_pool
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
initialize_model, initialize_model,
process_weights_after_loading, process_weights_after_loading,
) )
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
gms_client.clear_all_handles()
# Allocate model tensors using GMS memory pool # Allocate model tensors using GMS memory pool
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with use_mem_pool(pool, device=target_device): with gms_use_mem_pool("weights", target_device):
with target_device: with target_device:
model = initialize_model( model = initialize_model(
vllm_config=vllm_config, model_config=model_config vllm_config=vllm_config, model_config=model_config
......
...@@ -11,9 +11,7 @@ They should only allocate on their cache when they are the active/leader engine. ...@@ -11,9 +11,7 @@ They should only allocate on their cache when they are the active/leader engine.
from __future__ import annotations from __future__ import annotations
import logging import logging
import time
import torch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -48,6 +46,10 @@ class GMSShadowModelRunner(GPUModelRunner): ...@@ -48,6 +46,10 @@ class GMSShadowModelRunner(GPUModelRunner):
logger.info( logger.info(
"[Shadow] Init phase: stored config, skipping KV cache allocation" "[Shadow] Init phase: stored config, skipping KV cache allocation"
) )
print(
"[Shadow] Init phase: stored config, skipping KV cache allocation",
flush=True,
)
return {} return {}
return super().initialize_kv_cache_tensors(kv_cache_config, kernel_block_sizes) return super().initialize_kv_cache_tensors(kv_cache_config, kernel_block_sizes)
...@@ -86,7 +88,9 @@ class GMSShadowModelRunner(GPUModelRunner): ...@@ -86,7 +88,9 @@ class GMSShadowModelRunner(GPUModelRunner):
"""Allocate KV cache on wake using config stored during shadow init. """Allocate KV cache on wake using config stored during shadow init.
Called by GMSWorker.wake_up() after shadow init phase is exited. Called by GMSWorker.wake_up() after shadow init phase is exited.
Waits up to 60s for GPU memory to be freed. GMS kv_cache RW lock acquisition serves as the memory barrier — the
dying engine's abort() releases the lock and frees memory before we
can connect.
""" """
assert hasattr( assert hasattr(
self, "_shadow_kv_cache_config" self, "_shadow_kv_cache_config"
...@@ -95,47 +99,7 @@ class GMSShadowModelRunner(GPUModelRunner): ...@@ -95,47 +99,7 @@ class GMSShadowModelRunner(GPUModelRunner):
self, "_shadow_kernel_block_sizes" self, "_shadow_kernel_block_sizes"
), "_shadow_kernel_block_sizes not set — was enter_shadow_init() called?" ), "_shadow_kernel_block_sizes not set — was enter_shadow_init() called?"
# OOM remediation during failover: wait for the dying engine to release memory.
# TODO: This will be replaced with a barrier in GMS when we manage kv cache there instead
config = self._shadow_kv_cache_config config = self._shadow_kv_cache_config
kv_cache_bytes = sum(t.size for t in config.kv_cache_tensors)
free_bytes, _ = torch.cuda.mem_get_info()
if free_bytes < kv_cache_bytes:
logger.info(
"[Shadow] Waiting for GPU memory (need %.2f GiB, free %.2f GiB)",
kv_cache_bytes / (1 << 30),
free_bytes / (1 << 30),
)
deadline = time.monotonic() + 60.0
last_log = time.monotonic()
while free_bytes < kv_cache_bytes:
if time.monotonic() > deadline:
raise RuntimeError(
f"Timed out waiting for GPU memory: "
f"need {kv_cache_bytes / (1 << 30):.2f} GiB, "
f"free {free_bytes / (1 << 30):.2f} GiB"
)
now = time.monotonic()
if now - last_log >= 5.0:
elapsed = now - (deadline - 60.0)
remaining = deadline - now
logger.info(
"[Shadow] Still waiting for GPU memory: "
"need %.2f GiB, free %.2f GiB "
"(%.0fs elapsed, %.0fs remaining)",
kv_cache_bytes / (1 << 30),
free_bytes / (1 << 30),
elapsed,
remaining,
)
last_log = now
time.sleep(0.5)
free_bytes = torch.cuda.mem_get_info()[0]
logger.info(
"[Shadow] GPU memory available (free %.2f GiB), proceeding",
free_bytes / (1 << 30),
)
logger.info("[Shadow] Allocating KV cache on wake") logger.info("[Shadow] Allocating KV cache on wake")
...@@ -163,10 +127,11 @@ class GMSShadowModelRunner(GPUModelRunner): ...@@ -163,10 +127,11 @@ class GMSShadowModelRunner(GPUModelRunner):
logger.debug("[Shadow] KV transfer group not available") logger.debug("[Shadow] KV transfer group not available")
total_bytes = sum(t.numel() * t.element_size() for t in kv_caches.values()) total_bytes = sum(t.numel() * t.element_size() for t in kv_caches.values())
logger.info( msg = "[Shadow] Allocated KV cache on wake: %.2f GiB (%d tensors)" % (
"[Shadow] Allocated KV cache on wake: %.2f GiB (%d tensors)",
total_bytes / (1 << 30), total_bytes / (1 << 30),
len(kv_caches), len(kv_caches),
) )
logger.info(msg)
print(msg, flush=True)
return kv_caches return kv_caches
...@@ -48,12 +48,12 @@ def patch_memory_snapshot() -> None: ...@@ -48,12 +48,12 @@ def patch_memory_snapshot() -> None:
def patched_measure(self): def patched_measure(self):
original_measure(self) original_measure(self)
manager = get_gms_client_memory_manager() manager = get_gms_client_memory_manager("weights")
assert manager is not None, "GMS client is not initialized" assert manager is not None, "GMS client is not initialized"
if manager.granted_lock_type == GrantedLockType.RO: if manager.granted_lock_type == GrantedLockType.RO:
allocations = manager.list_handles() allocations = manager.list_handles()
committed_bytes = sum(alloc.get("aligned_size", 0) for alloc in allocations) committed_bytes = sum(alloc.aligned_size for alloc in allocations)
else: else:
# NOTE: by design, we want to assume we have the whole GPU when writing # NOTE: by design, we want to assume we have the whole GPU when writing
# weights for the first time, so we don't make an adjustment. # weights for the first time, so we don't make an adjustment.
......
...@@ -23,6 +23,7 @@ from gpu_memory_service import ( ...@@ -23,6 +23,7 @@ from gpu_memory_service import (
get_or_create_gms_client_memory_manager, get_or_create_gms_client_memory_manager,
) )
from gpu_memory_service.client.memory_manager import StaleMemoryLayoutError from gpu_memory_service.client.memory_manager import StaleMemoryLayoutError
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import RequestedLockType from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common import patch_empty_cache from gpu_memory_service.integrations.common import patch_empty_cache
...@@ -70,18 +71,16 @@ class GMSWorker(Worker): ...@@ -70,18 +71,16 @@ class GMSWorker(Worker):
# Establish weights GMS connection (so MemorySnapshot can query committed bytes). # Establish weights GMS connection (so MemorySnapshot can query committed bytes).
# Lock type is determined by model_loader_extra_config, set upstream by # Lock type is determined by model_loader_extra_config, set upstream by
# configure_gms_lock_mode() in main.py. # configure_gms_lock_mode() in main.py.
socket_path = get_socket_path(device)
extra = ( extra = (
getattr(self.vllm_config.load_config, "model_loader_extra_config", {}) or {} getattr(self.vllm_config.load_config, "model_loader_extra_config", {}) or {}
) )
mode = get_gms_lock_mode(extra) mode = get_gms_lock_mode(extra)
get_or_create_gms_client_memory_manager( get_or_create_gms_client_memory_manager(
socket_path, get_socket_path(device, "weights"),
device, device,
mode=mode, mode=mode,
tag="weights", tag="weights",
) )
# Parent will set device again (harmless) and do memory checks # Parent will set device again (harmless) and do memory checks
super().init_device() super().init_device()
...@@ -111,9 +110,9 @@ class GMSWorker(Worker): ...@@ -111,9 +110,9 @@ class GMSWorker(Worker):
torch.cuda.synchronize() torch.cuda.synchronize()
torch_peak = torch.cuda.max_memory_allocated() torch_peak = torch.cuda.max_memory_allocated()
# If weights are strictly loaded (RO), torch's memory accounting will miss them since we didn't go through the mempool # GMS weights mapped via cuMemMap are invisible to PyTorch's memory
# We therefore add in the memory of the weights into our accounting here # stats on RO engines. Add them explicitly. On RW engines, torch_peak
# This is not an issue on engines that write the weights and then downgrade to RO # already includes weights so skip to avoid double-counting.
weights_memory = int(getattr(self.model_runner, "model_memory_usage", 0)) weights_memory = int(getattr(self.model_runner, "model_memory_usage", 0))
if torch_peak < weights_memory: if torch_peak < weights_memory:
non_kv_cache_memory = torch_peak + weights_memory non_kv_cache_memory = torch_peak + weights_memory
...@@ -122,19 +121,62 @@ class GMSWorker(Worker): ...@@ -122,19 +121,62 @@ class GMSWorker(Worker):
projected_available = self.requested_memory - non_kv_cache_memory projected_available = self.requested_memory - non_kv_cache_memory
logger.info( msg = (
"[GMS] Shadow mode: projected available memory " "[GMS] Shadow mode: projected available memory "
"%.2f GiB (requested=%.2f GiB, non_kv=%.2f GiB, " "%.2f GiB (requested=%.2f GiB, non_kv=%.2f GiB, "
"torch_peak=%.2f GiB, weights=%.2f GiB)", "torch_peak=%.2f GiB, weights=%.2f GiB)"
projected_available / (1 << 30), % (
self.requested_memory / (1 << 30), projected_available / (1 << 30),
non_kv_cache_memory / (1 << 30), self.requested_memory / (1 << 30),
torch_peak / (1 << 30), non_kv_cache_memory / (1 << 30),
weights_memory / (1 << 30), torch_peak / (1 << 30),
weights_memory / (1 << 30),
)
) )
logger.info(msg)
print(msg, flush=True)
return int(projected_available) return int(projected_available)
def initialize_from_config(self, kv_cache_config) -> None:
"""Allocate KV cache with a dedicated RW-only GMS tag.
Also validates cudagraph mode for shadow mode compatibility.
"""
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
if is_shadow_mode():
# GMS client for kv cache is deferred to wake for shadow mode
# GMSShadowModelRunner.initialize_kv_cache intercepts and stores config without creating an allocation
self.model_runner.initialize_kv_cache(kv_cache_config)
elif self.vllm_config.model_config.enable_sleep_mode:
# Normal sleep/wake: create kv_cache GMS tag now for unmap/remap
device = self.local_rank
get_or_create_gms_client_memory_manager(
get_socket_path(device, "kv_cache"),
device,
mode=RequestedLockType.RW,
tag="kv_cache",
)
with gms_use_mem_pool("kv_cache", torch.device(f"cuda:{device}")):
self.model_runner.initialize_kv_cache(kv_cache_config)
else:
# No sleep mode: plain KV cache init
self.model_runner.initialize_kv_cache(kv_cache_config)
# Validate cudagraph mode for shadow mode compatibility
if is_shadow_mode():
from vllm.config import CUDAGraphMode
mode = self.model_runner.compilation_config.cudagraph_mode
if mode not in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE):
raise RuntimeError(
f"Shadow mode requires PIECEWISE cudagraph mode after resolution, "
f"but got {mode.name}. vLLM's config resolution overrode it."
)
def load_model(self, *args, **kwargs) -> None: def load_model(self, *args, **kwargs) -> None:
"""Load model with corrected memory accounting. """Load model with corrected memory accounting.
...@@ -166,54 +208,38 @@ class GMSWorker(Worker): ...@@ -166,54 +208,38 @@ class GMSWorker(Worker):
except Exception as e: except Exception as e:
logger.debug("[GMS] Could not correct memory accounting: %s", e) logger.debug("[GMS] Could not correct memory accounting: %s", e)
def initialize_from_config(self, kv_cache_config) -> None:
"""Initialize from config with post-init cudagraph mode assertion.
vLLM can try to upgrade the cudagraph mode in certain scenarios. We
assert that the final resolved mode is still compatible with shadow mode.
"""
super().initialize_from_config(kv_cache_config)
if is_shadow_mode():
from vllm.config import CUDAGraphMode
mode = self.model_runner.compilation_config.cudagraph_mode
if mode not in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE):
raise RuntimeError(
f"Shadow mode requires PIECEWISE cudagraph mode after resolution, "
f"but got {mode.name}. vLLM's config resolution overrode it."
)
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
""" """
vLLM sleep implementation with GMS integration. vLLM sleep implementation with GMS integration.
NOTE: `level` is a no-op here: weights are only unmapped (but remain in GPU memory).
NOTE: We do NOT call super().sleep() because it tries to copy GPU buffers to CPU, NOTE: We do NOT call super().sleep() because it tries to copy GPU buffers to CPU,
which segfaults on already-unmapped GMS memory. which segfaults on already-unmapped GMS memory.
Handles two cases for KV cache: Handles two cases for KV cache:
1. Normal: KV cache was allocated, sleep via CuMemAllocator 1. Normal: KV cache was allocated via GMS, unmap + abort
2. Shadow: KV cache was skipped at startup, nothing to do 2. Shadow: KV cache was skipped at startup, manager has no allocations
(unmap_all_vas is a no-op, abort disconnects)
""" """
free_bytes_before = torch.cuda.mem_get_info()[0] free_bytes_before = torch.cuda.mem_get_info()[0]
# Unmap GMS weights: synchronize + unmap all VAs + disconnect # Unmap GMS weights: synchronize + unmap all VAs + disconnect
manager = get_gms_client_memory_manager() weights_manager = get_gms_client_memory_manager("weights")
assert manager is not None, "GMS client is not initialized" assert weights_manager is not None, "GMS weights client is not initialized"
assert not manager.is_unmapped, "GMS weights are already unmapped" assert not weights_manager.is_unmapped, "GMS weights are already unmapped"
manager.unmap_all_vas() weights_manager.unmap_all_vas()
manager.disconnect() weights_manager.abort()
# Sleep KV cache via CuMemAllocator (discard, no CPU backup) # Unmap GMS KV cache: unmap all VAs + disconnect
# If KV cache was never allocated (shadow engine mode), this is a no-op # In shadow mode, kv_cache manager is deferred to wake — nothing to unmap.
from vllm.device_allocator.cumem import CuMemAllocator kv_cache_manager = get_gms_client_memory_manager("kv_cache")
if kv_cache_manager is not None:
kv_caches = getattr(self.model_runner, "kv_caches", None) assert not kv_cache_manager.is_unmapped, "GMS KV cache is already unmapped"
if kv_caches: kv_cache_manager.unmap_all_vas()
allocator = CuMemAllocator.get_instance() kv_cache_manager.abort()
allocator.sleep(offload_tags=tuple())
else: else:
logger.info("[GMS] KV cache not allocated (shadow mode), skipping sleep") logger.info(
"[GMS] No kv_cache manager (shadow mode), skipping kv_cache sleep"
)
free_bytes_after, total = torch.cuda.mem_get_info() free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before freed_bytes = free_bytes_after - free_bytes_before
...@@ -228,7 +254,7 @@ class GMSWorker(Worker): ...@@ -228,7 +254,7 @@ class GMSWorker(Worker):
"""vLLM wake implementation with GMS integration. """vLLM wake implementation with GMS integration.
Handles two cases for KV cache: Handles two cases for KV cache:
1. Normal: KV cache was allocated at startup, reallocate via CuMemAllocator 1. Normal: KV cache was allocated at startup, reconnect + reallocate + remap
2. Shadow: KV cache was skipped at startup, allocate via allocate_kv_cache_on_wake() 2. Shadow: KV cache was skipped at startup, allocate via allocate_kv_cache_on_wake()
""" """
if ( if (
...@@ -241,16 +267,16 @@ class GMSWorker(Worker): ...@@ -241,16 +267,16 @@ class GMSWorker(Worker):
tags = ["weights", "kv_cache"] tags = ["weights", "kv_cache"]
if "weights" in tags: if "weights" in tags:
manager = get_gms_client_memory_manager() weights_manager = get_gms_client_memory_manager("weights")
assert manager is not None, "GMS client is not initialized" assert weights_manager is not None, "GMS weights client is not initialized"
assert manager.is_unmapped, "GMS weights are not unmapped" assert weights_manager.is_unmapped, "GMS weights are not unmapped"
# These errors are fatal and unrecoverable in a worker subprocess: # These errors are fatal and unrecoverable in a worker subprocess:
# the worker cannot serve requests without weights. sys.exit(1) # the worker cannot serve requests without weights. sys.exit(1)
# ensures clean termination so the orchestrator (K8s) can restart. # ensures clean termination so the orchestrator (K8s) can restart.
try: try:
manager.connect(RequestedLockType.RO, timeout_ms=30_000) weights_manager.connect(RequestedLockType.RO, timeout_ms=30_000)
manager.remap_all_vas() weights_manager.remap_all_vas()
except TimeoutError: except TimeoutError:
logger.error( logger.error(
"Fatal: timed out waiting for GMS RO lock during remap " "Fatal: timed out waiting for GMS RO lock during remap "
...@@ -270,15 +296,30 @@ class GMSWorker(Worker): ...@@ -270,15 +296,30 @@ class GMSWorker(Worker):
# Check if KV cache was skipped at startup (shadow engine mode) # Check if KV cache was skipped at startup (shadow engine mode)
kv_caches = getattr(self.model_runner, "kv_caches", None) kv_caches = getattr(self.model_runner, "kv_caches", None)
if not kv_caches: if not kv_caches:
# Shadow mode: create kv_cache manager now (deferred from init
# to avoid RW lock contention between concurrent engines).
logger.info("[GMS] KV cache not allocated - allocating on wake") logger.info("[GMS] KV cache not allocated - allocating on wake")
self.model_runner.allocate_kv_cache_on_wake() get_or_create_gms_client_memory_manager(
get_socket_path(self.local_rank, "kv_cache"),
self.local_rank,
mode=RequestedLockType.RW,
tag="kv_cache",
)
with gms_use_mem_pool(
"kv_cache", torch.device("cuda", self.local_rank)
):
self.model_runner.allocate_kv_cache_on_wake()
logger.info("[GMS] Successfully allocated KV cache on wake") logger.info("[GMS] Successfully allocated KV cache on wake")
else: else:
# Normal case: KV cache was allocated, reallocate via CuMemAllocator # Normal case: KV cache was allocated via GMS, reconnect + reallocate + remap
from vllm.device_allocator.cumem import CuMemAllocator kv_cache_manager = get_gms_client_memory_manager("kv_cache")
assert (
allocator = CuMemAllocator.get_instance() kv_cache_manager is not None
allocator.wake_up(tags=["kv_cache"]) ), "GMS KV cache client is not initialized"
assert kv_cache_manager.is_unmapped, "GMS KV cache is not unmapped"
kv_cache_manager.connect(RequestedLockType.RW)
kv_cache_manager.reallocate_all_handles(tag="kv_cache")
kv_cache_manager.remap_all_vas()
# Reinitialize FP8 KV scales if needed # Reinitialize FP8 KV scales if needed
if self.cache_config.cache_dtype.startswith("fp8") and hasattr( if self.cache_config.cache_dtype.startswith("fp8") and hasattr(
...@@ -287,12 +328,18 @@ class GMSWorker(Worker): ...@@ -287,12 +328,18 @@ class GMSWorker(Worker):
self.model_runner.init_fp8_kv_scales() self.model_runner.init_fp8_kv_scales()
def _maybe_get_memory_pool_context(self, tag: str): def _maybe_get_memory_pool_context(self, tag: str):
"""Skip CuMemAllocator for weights when using GMS. """Route tag-scoped runtime allocations to the right allocator.
GMS manages its own memory pool for weights, so we don't want vLLM's Weight tensors are allocated explicitly in the GMS model-loader path,
CuMemAllocator to interfere. not through vLLM's tagged runtime allocator hook. For `weights` we
therefore only suppress CuMemAllocator here so it does not interfere
with the loader-managed GMS allocations. `kv_cache` is the tag that
actually allocates through this hook, so it uses the dedicated GMS
mempool.
""" """
if tag == "weights": if tag == "weights":
logger.debug("[GMS] Skipping CuMemAllocator for weights") logger.debug("[GMS] Skipping CuMemAllocator for weights")
return nullcontext() return nullcontext()
if tag == "kv_cache":
return gms_use_mem_pool("kv_cache", torch.device("cuda", self.local_rank))
return super()._maybe_get_memory_pool_context(tag) return super()._maybe_get_memory_pool_context(tag)
...@@ -9,26 +9,33 @@ from gpu_memory_service.common.types import ( ...@@ -9,26 +9,33 @@ from gpu_memory_service.common.types import (
ServerState, ServerState,
StateSnapshot, StateSnapshot,
) )
from gpu_memory_service.server.handler import MetadataEntry, RequestHandler from gpu_memory_service.server.allocations import (
from gpu_memory_service.server.locking import Connection, GMSLocalFSM
from gpu_memory_service.server.memory_manager import (
AllocationInfo, AllocationInfo,
AllocationNotFoundError, AllocationNotFoundError,
GMSServerMemoryManager, GMSAllocationManager,
) )
from gpu_memory_service.server.gms import GMS, MetadataEntry
from gpu_memory_service.server.rpc import GMSRPCServer from gpu_memory_service.server.rpc import GMSRPCServer
from gpu_memory_service.server.session import (
Connection,
GMSSessionManager,
InvalidTransition,
OperationNotAllowed,
)
__all__ = [ __all__ = [
"GMSRPCServer", "GMSRPCServer",
"GMSServerMemoryManager", "GMS",
"GMSSessionManager",
"GMSAllocationManager",
"AllocationInfo", "AllocationInfo",
"AllocationNotFoundError", "AllocationNotFoundError",
"MetadataEntry", "MetadataEntry",
"Connection", "Connection",
"GrantedLockType", "GrantedLockType",
"RequestedLockType", "RequestedLockType",
"RequestHandler",
"ServerState", "ServerState",
"GMSLocalFSM",
"StateSnapshot", "StateSnapshot",
"InvalidTransition",
"OperationNotAllowed",
] ]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Server-side CUDA allocation store."""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Callable, Optional
from uuid import uuid4
from gpu_memory_service.common.cuda_utils import (
align_to_granularity,
cuda_ensure_initialized,
cumem_create_tolerate_oom,
cumem_export_to_shareable_handle,
cumem_get_allocation_granularity,
cumem_release,
)
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class AllocationInfo:
allocation_id: str
size: int
aligned_size: int
handle: int
tag: str
layout_slot: int
created_at: float
class AllocationNotFoundError(Exception):
"""Raised when an allocation_id doesn't exist."""
class GMSAllocationManager:
"""Server-side CUDA VMM allocation store."""
def __init__(
self,
device: int = 0,
*,
allocation_retry_interval: float = 0.5,
allocation_retry_timeout: Optional[float] = None,
):
if allocation_retry_interval <= 0:
raise ValueError(
f"allocation_retry_interval must be > 0, got {allocation_retry_interval}"
)
if allocation_retry_timeout is not None and allocation_retry_timeout <= 0:
raise ValueError(
f"allocation_retry_timeout must be > 0 when set, got {allocation_retry_timeout}"
)
self._device = device
self._allocations: dict[str, AllocationInfo] = {}
self._next_layout_slot = 0
cuda_ensure_initialized()
self._granularity = cumem_get_allocation_granularity(device)
self._allocation_retry_interval = allocation_retry_interval
self._allocation_retry_timeout = allocation_retry_timeout
logger.info(
"GMSAllocationManager initialized: device=%d, granularity=%d, "
"alloc_retry_interval=%.3f, alloc_retry_timeout=%s",
device,
self._granularity,
self._allocation_retry_interval,
(
f"{self._allocation_retry_timeout:.3f}"
if self._allocation_retry_timeout is not None
else "none"
),
)
@property
def device(self) -> int:
return self._device
@property
def allocation_count(self) -> int:
return len(self._allocations)
async def allocate(
self,
size: int,
tag: str = "default",
is_connected: Optional[Callable[[], bool]] = None,
on_oom: Optional[Callable[[], None]] = None,
) -> AllocationInfo:
if size <= 0:
raise ValueError(f"size must be > 0, got {size}")
aligned_size = align_to_granularity(size, self._granularity)
started_at = time.monotonic()
reported_oom = False
while True:
if is_connected is not None and not is_connected():
raise ConnectionAbortedError(
"RW client disconnected during allocation retry"
)
allocated, handle = cumem_create_tolerate_oom(aligned_size, self._device)
if allocated:
break
if on_oom is not None and not reported_oom:
on_oom()
reported_oom = True
if self._allocation_retry_timeout is not None:
waited = time.monotonic() - started_at
if waited >= self._allocation_retry_timeout:
raise TimeoutError(
"Timed out waiting for GPU memory: "
f"requested_size={size}, aligned_size={aligned_size}, "
f"tag={tag}, waited_sec={waited:.3f}"
)
logger.warning(
"cuMemCreate OOM for aligned_size=%d bytes, tag=%s; retrying in %.3fs",
aligned_size,
tag,
self._allocation_retry_interval,
)
await asyncio.sleep(self._allocation_retry_interval)
info = AllocationInfo(
allocation_id=str(uuid4()),
size=size,
aligned_size=aligned_size,
handle=int(handle),
tag=tag,
layout_slot=self._next_layout_slot,
created_at=time.time(),
)
self._next_layout_slot = info.layout_slot + 1
self._allocations[info.allocation_id] = info
logger.debug(
"Allocated %s: size=%d, aligned=%d, tag=%s, slot=%d",
info.allocation_id,
size,
aligned_size,
tag,
info.layout_slot,
)
return info
def export_allocation(self, allocation_id: str) -> int:
return cumem_export_to_shareable_handle(
self.get_allocation(allocation_id).handle
)
def free_allocation(self, allocation_id: str) -> bool:
info = self._allocations.get(allocation_id)
if info is None:
return False
cumem_release(info.handle)
self._allocations.pop(allocation_id, None)
logger.debug("Freed allocation: %s", allocation_id)
return True
def clear_all(self) -> int:
allocation_ids = list(self._allocations)
for allocation_id in allocation_ids:
info = self._allocations[allocation_id]
cumem_release(info.handle)
self._allocations.pop(allocation_id, None)
if allocation_ids:
logger.info("Cleared %d allocations", len(allocation_ids))
self._next_layout_slot = 0
return len(allocation_ids)
def get_allocation(self, allocation_id: str) -> AllocationInfo:
info = self._allocations.get(allocation_id)
if info is None:
raise AllocationNotFoundError(f"Unknown allocation: {allocation_id}")
return info
def list_allocations(self, tag: Optional[str] = None) -> list[AllocationInfo]:
allocations = list(self._allocations.values())
allocations.sort(key=lambda info: info.layout_slot)
if tag is None:
return allocations
return [info for info in allocations if info.tag == tag]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Top-level server-side GMS service."""
from __future__ import annotations
import hashlib
import logging
from collections import deque
from dataclasses import dataclass
from typing import Callable, Optional
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
CommitRequest,
CommitResponse,
ExportAllocationRequest,
ExportAllocationResponse,
FreeAllocationRequest,
FreeAllocationResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetEventHistoryResponse,
GetLockStateRequest,
GetLockStateResponse,
GetRuntimeStateResponse,
GetStateHashRequest,
GetStateHashResponse,
GMSRuntimeEvent,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from .allocations import AllocationInfo, GMSAllocationManager
from .session import Connection, GMSSessionManager
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class MetadataEntry:
allocation_id: str
offset_bytes: int
value: bytes
class GMS:
"""Owns all non-transport server state."""
_MAX_EVENTS = 256
def __init__(
self,
device: int = 0,
*,
allocation_retry_interval: float = 0.5,
allocation_retry_timeout: Optional[float] = None,
):
self._allocations = GMSAllocationManager(
device,
allocation_retry_interval=allocation_retry_interval,
allocation_retry_timeout=allocation_retry_timeout,
)
self._sessions = GMSSessionManager()
self._events: deque[GMSRuntimeEvent] = deque(maxlen=self._MAX_EVENTS)
self._metadata: dict[str, MetadataEntry] = {}
self._memory_layout_hash = ""
logger.info("GMS initialized: device=%d", device)
@property
def state(self) -> ServerState:
return self._sessions.state
@property
def committed(self) -> bool:
return self._sessions.snapshot().committed
@property
def allocation_count(self) -> int:
return self._allocations.allocation_count
def is_ready(self) -> bool:
return self._sessions.snapshot().is_ready
def get_runtime_state(self) -> GetRuntimeStateResponse:
session = self._sessions.snapshot()
return GetRuntimeStateResponse(
state=session.state.name,
has_rw_session=session.has_rw_session,
ro_session_count=session.ro_session_count,
waiting_writers=session.waiting_writers,
committed=session.committed,
is_ready=session.is_ready,
allocation_count=self._allocations.allocation_count,
memory_layout_hash=self._memory_layout_hash,
)
def get_event_history(self) -> GetEventHistoryResponse:
return GetEventHistoryResponse(events=list(self._events))
def next_session_id(self) -> str:
return self._sessions.next_session_id()
async def acquire_lock(
self,
mode: RequestedLockType,
timeout_ms: int | None,
session_id: str,
) -> GrantedLockType | None:
return await self._sessions.acquire_lock(mode, timeout_ms, session_id)
async def cancel_connect(
self,
session_id: str,
mode: GrantedLockType | None,
) -> None:
await self._sessions.cancel_connect(session_id, mode)
def _validate_metadata_target(
self,
allocation: AllocationInfo,
offset_bytes: int,
) -> None:
if offset_bytes < 0:
raise ValueError(f"offset_bytes must be >= 0, got {offset_bytes}")
if offset_bytes >= allocation.aligned_size:
raise ValueError(
f"offset_bytes {offset_bytes} out of range for allocation {allocation.allocation_id} "
f"(aligned_size={allocation.aligned_size})"
)
def _drop_metadata_for_allocation(self, allocation_id: str) -> int:
keys_to_remove = [
key
for key, entry in self._metadata.items()
if entry.allocation_id == allocation_id
]
for key in keys_to_remove:
self._metadata.pop(key, None)
return len(keys_to_remove)
def _validate_metadata_integrity(
self,
allocations_by_id: dict[str, AllocationInfo],
) -> None:
for key, entry in self._metadata.items():
info = allocations_by_id.get(entry.allocation_id)
if info is None:
raise AssertionError(
f"Metadata key {key!r} references missing allocation "
f"{entry.allocation_id!r}"
)
if entry.offset_bytes < 0 or entry.offset_bytes >= info.aligned_size:
raise AssertionError(
f"Metadata key {key!r} has invalid offset {entry.offset_bytes} "
f"for allocation {entry.allocation_id!r} "
f"(aligned_size={info.aligned_size})"
)
def _compute_memory_layout_hash(self, allocations: list[AllocationInfo]) -> str:
h = hashlib.sha256()
allocation_slots_by_id: dict[str, int] = {}
for info in sorted(allocations, key=lambda info: info.layout_slot):
allocation_slots_by_id[info.allocation_id] = info.layout_slot
h.update(
f"{info.layout_slot}:{info.size}:{info.aligned_size}:{info.tag}".encode()
)
for key in sorted(self._metadata):
entry = self._metadata[key]
layout_slot = allocation_slots_by_id[entry.allocation_id]
h.update(f"{key}:{layout_slot}:{entry.offset_bytes}:".encode())
h.update(entry.value)
return h.hexdigest()
def _clear_layout_state(self) -> int:
self._metadata.clear()
self._memory_layout_hash = ""
return self._allocations.clear_all()
def on_connect(self, conn: Connection) -> None:
if conn.mode == GrantedLockType.RW:
had_committed_layout = self._sessions.snapshot().committed
cleared = self._clear_layout_state()
if had_committed_layout:
self._events.append(
GMSRuntimeEvent(
kind="allocations_cleared",
allocation_count=cleared,
)
)
self._sessions.on_connect(conn)
if conn.mode == GrantedLockType.RW:
self._events.append(GMSRuntimeEvent(kind="rw_connected"))
async def cleanup_connection(self, conn: Connection | None) -> None:
event = self._sessions.begin_cleanup(conn)
if event == StateEvent.RW_ABORT:
logger.warning("RW aborted; clearing active layout")
cleared = self._clear_layout_state()
self._events.append(GMSRuntimeEvent(kind="rw_aborted"))
self._events.append(
GMSRuntimeEvent(
kind="allocations_cleared",
allocation_count=cleared,
)
)
await self._sessions.finish_cleanup(conn)
async def handle_request(
self,
conn: Connection,
msg,
is_connected: Callable[[], bool],
) -> tuple[object, int, bool]:
msg_type = type(msg)
self._sessions.check_operation(msg_type, conn)
if msg_type is CommitRequest:
if self.state != ServerState.RW:
raise AssertionError("RW state is not active")
allocations = self._allocations.list_allocations()
allocations_by_id = {info.allocation_id: info for info in allocations}
self._validate_metadata_integrity(allocations_by_id)
self._memory_layout_hash = self._compute_memory_layout_hash(allocations)
logger.info(
"Committed layout with state hash: %s...",
self._memory_layout_hash[:16],
)
self._sessions.on_commit(conn)
self._events.append(GMSRuntimeEvent(kind="committed"))
return CommitResponse(success=True), -1, True
if msg_type is AllocateRequest:
if self.state != ServerState.RW:
raise AssertionError("RW state is not active")
info = await self._allocations.allocate(
size=msg.size,
tag=msg.tag,
is_connected=is_connected,
on_oom=lambda: self._events.append(
GMSRuntimeEvent(
kind="allocation_oom",
allocation_count=self._allocations.allocation_count,
)
),
)
return (
AllocateResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
layout_slot=info.layout_slot,
),
-1,
False,
)
if msg_type is GetLockStateRequest:
snapshot = self._sessions.snapshot()
return (
GetLockStateResponse(
state=snapshot.state.name,
has_rw_session=snapshot.has_rw_session,
ro_session_count=snapshot.ro_session_count,
waiting_writers=snapshot.waiting_writers,
committed=snapshot.committed,
is_ready=snapshot.is_ready,
),
-1,
False,
)
if msg_type is GetAllocationStateRequest:
return (
GetAllocationStateResponse(
allocation_count=self._allocations.allocation_count
),
-1,
False,
)
if msg_type is ExportAllocationRequest:
info = self._allocations.get_allocation(msg.allocation_id)
fd = self._allocations.export_allocation(info.allocation_id)
return (
ExportAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
layout_slot=info.layout_slot,
),
fd,
False,
)
if msg_type is GetStateHashRequest:
return (
GetStateHashResponse(memory_layout_hash=self._memory_layout_hash),
-1,
False,
)
if msg_type is GetAllocationRequest:
info = self._allocations.get_allocation(msg.allocation_id)
return (
GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
layout_slot=info.layout_slot,
),
-1,
False,
)
if msg_type is ListAllocationsRequest:
return (
ListAllocationsResponse(
allocations=[
GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
layout_slot=info.layout_slot,
)
for info in self._allocations.list_allocations(msg.tag)
]
),
-1,
False,
)
if msg_type is FreeAllocationRequest:
success = self._allocations.free_allocation(msg.allocation_id)
if success:
self._drop_metadata_for_allocation(msg.allocation_id)
return (
FreeAllocationResponse(success=success),
-1,
False,
)
if msg_type is MetadataPutRequest:
allocation = self._allocations.get_allocation(msg.allocation_id)
self._validate_metadata_target(allocation, msg.offset_bytes)
self._metadata[msg.key] = MetadataEntry(
allocation_id=allocation.allocation_id,
offset_bytes=msg.offset_bytes,
value=msg.value,
)
return MetadataPutResponse(success=True), -1, False
if msg_type is MetadataGetRequest:
entry = self._metadata.get(msg.key)
if entry is None:
return MetadataGetResponse(found=False), -1, False
return (
MetadataGetResponse(
found=True,
allocation_id=entry.allocation_id,
offset_bytes=entry.offset_bytes,
value=entry.value,
),
-1,
False,
)
if msg_type is MetadataDeleteRequest:
return (
MetadataDeleteResponse(
deleted=self._metadata.pop(msg.key, None) is not None
),
-1,
False,
)
if msg_type is MetadataListRequest:
if not msg.prefix:
keys = sorted(self._metadata)
else:
keys = sorted(
key for key in self._metadata if key.startswith(msg.prefix)
)
return MetadataListResponse(keys=keys), -1, False
raise ValueError(f"Unknown request: {msg_type.__name__}")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Request handlers for GPU Memory Service."""
import hashlib
import logging
from dataclasses import dataclass
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
ClearAllResponse,
FreeRequest,
FreeResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateResponse,
GetLockStateResponse,
GetStateHashResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import derive_state
from .memory_manager import AllocationNotFoundError, GMSServerMemoryManager
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class MetadataEntry:
allocation_id: str
offset_bytes: int
value: bytes
class RequestHandler:
"""Handles allocation and metadata requests."""
def __init__(self, device: int = 0):
self._memory_manager = GMSServerMemoryManager(device)
self._metadata: dict[str, MetadataEntry] = {}
self._memory_layout_hash: str = (
"" # Hash of allocations + metadata, computed on commit
)
logger.info(f"RequestHandler initialized: device={device}")
@property
def granularity(self) -> int:
return self._memory_manager.granularity
def on_rw_abort(self) -> None:
"""Called when RW connection closes without commit."""
logger.warning("RW aborted; clearing allocations and metadata")
self._memory_manager.clear_all()
self._metadata.clear()
self._memory_layout_hash = ""
def on_commit(self) -> None:
"""Called when RW connection commits. Computes state hash."""
self._memory_layout_hash = self._compute_memory_layout_hash()
logger.info(f"Committed with state hash: {self._memory_layout_hash[:16]}...")
def _compute_memory_layout_hash(self) -> str:
"""Compute hash of current allocations + metadata."""
h = hashlib.sha256()
# Hash allocations (sorted by ID for determinism)
for info in sorted(
self._memory_manager.list_allocations(), key=lambda x: x.allocation_id
):
h.update(
f"{info.allocation_id}:{info.size}:{info.aligned_size}:{info.tag}".encode()
)
# Hash metadata (sorted by key for determinism)
for key in sorted(self._metadata.keys()):
entry = self._metadata[key]
h.update(f"{key}:{entry.allocation_id}:{entry.offset_bytes}:".encode())
h.update(entry.value)
return h.hexdigest()
def on_shutdown(self) -> None:
"""Called on server shutdown."""
if self._memory_manager.allocation_count > 0:
count = self._memory_manager.clear_all()
self._metadata.clear()
logger.info(f"Released {count} GPU allocations during shutdown")
# ==================== State Queries ====================
def handle_get_lock_state(
self,
has_rw: bool,
ro_count: int,
waiting_writers: int,
committed: bool,
) -> GetLockStateResponse:
"""Get lock/session state."""
state = derive_state(has_rw, ro_count, committed)
return GetLockStateResponse(
state=state.value,
has_rw_session=has_rw,
ro_session_count=ro_count,
waiting_writers=waiting_writers,
committed=committed,
is_ready=committed and not has_rw,
)
def handle_get_allocation_state(self) -> GetAllocationStateResponse:
"""Get allocation state."""
return GetAllocationStateResponse(
allocation_count=self._memory_manager.allocation_count,
total_bytes=self._memory_manager.total_bytes,
)
# ==================== Allocation Operations ====================
def handle_allocate(self, req: AllocateRequest) -> AllocateResponse:
"""Create physical memory allocation.
Requires RW connection (enforced by server).
"""
info = self._memory_manager.allocate(req.size, req.tag)
return AllocateResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
)
def handle_export(self, allocation_id: str) -> tuple[GetAllocationResponse, int]:
"""Export allocation as POSIX FD.
Returns (response, fd). Caller must close fd after sending.
"""
fd = self._memory_manager.export_fd(allocation_id)
info = self._memory_manager.get_allocation(allocation_id)
response = GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
)
return response, fd
def handle_get_allocation(self, req: GetAllocationRequest) -> GetAllocationResponse:
"""Get allocation info."""
try:
info = self._memory_manager.get_allocation(req.allocation_id)
return GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
)
except AllocationNotFoundError:
raise ValueError(f"Unknown allocation: {req.allocation_id}") from None
def handle_list_allocations(
self, req: ListAllocationsRequest
) -> ListAllocationsResponse:
"""List all allocations."""
allocations = self._memory_manager.list_allocations(req.tag)
result = [
{
"allocation_id": info.allocation_id,
"size": info.size,
"aligned_size": info.aligned_size,
"tag": info.tag,
}
for info in allocations
]
return ListAllocationsResponse(allocations=result)
def handle_free(self, req: FreeRequest) -> FreeResponse:
"""Free single allocation.
Requires RW connection (enforced by server).
"""
success = self._memory_manager.free(req.allocation_id)
return FreeResponse(success=success)
def handle_clear_all(self) -> ClearAllResponse:
"""Clear all allocations and metadata.
Requires RW connection (enforced by server).
"""
count = self._memory_manager.clear_all()
self._metadata.clear()
return ClearAllResponse(cleared_count=count)
# ==================== Metadata Operations ====================
def handle_metadata_put(self, req: MetadataPutRequest) -> MetadataPutResponse:
self._metadata[req.key] = MetadataEntry(
req.allocation_id, req.offset_bytes, req.value
)
return MetadataPutResponse(success=True)
def handle_metadata_get(self, req: MetadataGetRequest) -> MetadataGetResponse:
entry = self._metadata.get(req.key)
if entry is None:
return MetadataGetResponse(found=False)
return MetadataGetResponse(
found=True,
allocation_id=entry.allocation_id,
offset_bytes=entry.offset_bytes,
value=entry.value,
)
def handle_metadata_delete(
self, req: MetadataDeleteRequest
) -> MetadataDeleteResponse:
return MetadataDeleteResponse(
deleted=self._metadata.pop(req.key, None) is not None
)
def handle_metadata_list(self, req: MetadataListRequest) -> MetadataListResponse:
keys = (
[k for k in self._metadata if k.startswith(req.prefix)]
if req.prefix
else list(self._metadata)
)
return MetadataListResponse(keys=sorted(keys))
def handle_get_memory_layout_hash(self) -> GetStateHashResponse:
return GetStateHashResponse(memory_layout_hash=self._memory_layout_hash)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA VMM allocation manager - pure business logic, no threading/transport.
This module contains the GMSServerMemoryManager class which handles physical GPU memory
allocations via CUDA Virtual Memory Management (VMM) API. It creates shareable
memory without mapping it locally (no CUDA context needed on the server).
The GMSServerMemoryManager is NOT thread-safe. Callers must provide external
synchronization (e.g., via LockManager ensuring single-writer access).
"""
import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
from uuid import uuid4
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import (
align_to_granularity,
check_cuda_result,
ensure_cuda_initialized,
get_allocation_granularity,
)
logger = logging.getLogger(__name__)
@dataclass
class AllocationInfo:
"""Information about a single GPU memory allocation.
Attributes:
allocation_id: Unique identifier for this allocation
size: Requested size in bytes
aligned_size: Actual size after alignment to VMM granularity
handle: CUmemGenericAllocationHandle value
tag: User-provided tag for grouping allocations
created_at: Timestamp when allocation was created
"""
allocation_id: str
size: int
aligned_size: int
handle: int
tag: str
created_at: float
class AllocationNotFoundError(Exception):
"""Raised when an allocation_id doesn't exist."""
pass
class GMSServerMemoryManager:
"""GPU Memory Service server-side memory manager.
Manages CUDA VMM physical memory allocations. This class handles the core
memory operations using CUDA Virtual Memory Management (VMM) API. It creates
physical allocations that can be exported as POSIX file descriptors for
sharing with other processes.
Key design points:
- No VA mapping: The memory manager never maps memory to virtual addresses,
so it doesn't create a CUDA context. This allows it to survive GPU
driver failures.
- NOT thread-safe: Callers must provide external synchronization.
The GMSLocalFSM's RW/RO semantics ensure single-writer access.
"""
def __init__(self, device: int = 0):
self._device = device
self._allocations: Dict[str, AllocationInfo] = {}
ensure_cuda_initialized()
self._granularity = get_allocation_granularity(device)
logger.info(
f"GMSServerMemoryManager initialized: device={device}, granularity={self._granularity}"
)
@property
def device(self) -> int:
return self._device
@property
def granularity(self) -> int:
return self._granularity
@property
def allocation_count(self) -> int:
return len(self._allocations)
@property
def total_bytes(self) -> int:
return sum(info.aligned_size for info in self._allocations.values())
def _get(self, allocation_id: str) -> AllocationInfo:
info = self._allocations.get(allocation_id)
if info is None:
raise AllocationNotFoundError(f"Unknown allocation: {allocation_id}")
return info
def _release(self, info: AllocationInfo) -> None:
(result,) = cuda.cuMemRelease(info.handle)
if result != cuda.CUresult.CUDA_SUCCESS:
logger.warning(f"cuMemRelease failed for {info.allocation_id}: {result}")
def allocate(self, size: int, tag: str = "default") -> AllocationInfo:
"""Create a physical memory allocation (no VA mapping).
Uses cuMemCreate to allocate physical GPU memory that can be exported
as a file descriptor for sharing with other processes.
Args:
size: Requested size in bytes (will be aligned up to granularity)
tag: Tag for grouping allocations (e.g., "weights", "kv_cache")
Returns:
AllocationInfo with allocation_id, aligned_size, handle
Raises:
RuntimeError: If CUDA allocation fails
"""
aligned_size = align_to_granularity(size, self._granularity)
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = self._device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, handle = cuda.cuMemCreate(aligned_size, prop, 0)
check_cuda_result(result, "cuMemCreate")
info = AllocationInfo(
allocation_id=str(uuid4()),
size=size,
aligned_size=aligned_size,
handle=int(handle),
tag=tag,
created_at=time.time(),
)
self._allocations[info.allocation_id] = info
logger.debug(
f"Allocated {info.allocation_id}: size={size}, aligned={aligned_size}, tag={tag}"
)
return info
def export_fd(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD for SCM_RIGHTS transfer.
The returned file descriptor can be sent to another process via
Unix domain socket SCM_RIGHTS. The receiving process can then
import it using cuMemImportFromShareableHandle.
IMPORTANT: The caller MUST close the returned FD after sendmsg()
to avoid file descriptor leaks.
Args:
allocation_id: ID of allocation to export
Returns:
File descriptor (caller owns, must close after sending)
Raises:
AllocationNotFoundError: If allocation_id doesn't exist
RuntimeError: If CUDA export fails
"""
info = self._get(allocation_id)
result, fd = cuda.cuMemExportToShareableHandle(
info.handle,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
0,
)
check_cuda_result(result, "cuMemExportToShareableHandle")
return int(fd)
def free(self, allocation_id: str) -> bool:
"""Release physical memory for a single allocation.
Args:
allocation_id: ID of allocation to free
Returns:
True if allocation existed and was freed, False otherwise
"""
info = self._allocations.pop(allocation_id, None)
if info is None:
return False
self._release(info)
logger.debug(f"Freed allocation: {allocation_id}")
return True
def clear_all(self) -> int:
"""Release ALL allocations.
Used by loaders before loading a new model, or during cleanup
when a writer aborts without committing.
Returns:
Number of allocations cleared
"""
count = len(self._allocations)
for info in self._allocations.values():
self._release(info)
self._allocations.clear()
logger.info(f"Cleared {count} allocations")
return count
def get_allocation(self, allocation_id: str) -> AllocationInfo:
"""Get allocation info. Raises AllocationNotFoundError if not found."""
return self._get(allocation_id)
def list_allocations(self, tag: Optional[str] = None) -> List[AllocationInfo]:
"""List all allocations, optionally filtered by tag."""
if tag is None:
return list(self._allocations.values())
return [info for info in self._allocations.values() if info.tag == tag]
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service Shadow Engine Failover Test for SGLang."""
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from .utils.common import run_shadow_failover_test
from .utils.sglang import SGLangWithGMSProcess
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover(
request, runtime_services, gms_ports, predownload_models
):
ports = gms_ports
run_shadow_failover_test(
request,
ports,
make_shadow=lambda: SGLangWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
),
make_primary=lambda: SGLangWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_sglang"],
ports["frontend"],
),
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service Shadow Engine Failover Test for vLLM."""
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from .utils.common import run_shadow_failover_test
from .utils.vllm import VLLMWithGMSProcess
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover(
request, runtime_services, gms_ports, predownload_models
):
ports = gms_ports
run_shadow_failover_test(
request,
ports,
make_shadow=lambda: VLLMWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
),
make_primary=lambda: VLLMWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_kv_event"],
ports["primary_nixl"],
ports["frontend"],
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment