Unverified Commit 91a8c6af authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix: simplify GMS layout state and harden GPU-backed flows (#7006)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
Co-authored-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: default avatarhhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent dd7ceb4a
......@@ -14,11 +14,13 @@ Usage:
from __future__ import annotations
import logging
from dataclasses import replace
import torch
from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import setup_meta_tensor_workaround
from gpu_memory_service.integrations.common.utils import (
setup_meta_tensor_workaround,
strip_gms_model_loader_config,
)
from gpu_memory_service.integrations.sglang.patches import (
patch_model_runner,
patch_static_state_for_gms,
......@@ -50,7 +52,10 @@ class GMSModelLoader:
if self._default_loader is None:
from sglang.srt.model_loader.loader import DefaultModelLoader
config = replace(self.load_config, load_format="auto")
config = strip_gms_model_loader_config(
self.load_config,
load_format="auto",
)
self._default_loader = DefaultModelLoader(config)
return self._default_loader
......@@ -124,7 +129,10 @@ class GMSModelLoader:
with meta_device:
model = get_model(
model_config=model_config,
load_config=replace(self.load_config, load_format="dummy"),
load_config=strip_gms_model_loader_config(
self.load_config,
load_format="dummy",
),
device_config=device_config,
)
......
......@@ -10,11 +10,12 @@
from __future__ import annotations
import inspect
import logging
from contextlib import contextmanager
from typing import Optional
import torch
from gpu_memory_service.common.utils import get_socket_path
logger = logging.getLogger(__name__)
......@@ -34,6 +35,7 @@ def patch_torch_memory_saver() -> None:
return
try:
import torch_memory_saver
import torch_memory_saver.entrypoint as entrypoint_module
except ImportError:
logger.debug("[GMS] torch_memory_saver not installed, skipping patch")
......@@ -41,6 +43,7 @@ def patch_torch_memory_saver() -> None:
# Store reference to original method
original_ensure_initialized = entrypoint_module.TorchMemorySaver._ensure_initialized
original_configure_subprocess = torch_memory_saver.configure_subprocess
def patched_ensure_initialized(self):
"""Patched _ensure_initialized that uses GPU Memory Service implementation."""
......@@ -63,10 +66,7 @@ def patch_torch_memory_saver() -> None:
# Get device from torch.cuda.current_device() (already set by SGLang)
device_index = torch.cuda.current_device()
# Resolve socket path from env or default
socket_path = get_socket_path(device_index)
# Create underlying torch impl for non-weights tags (KV cache etc.)
# Create underlying torch impl for non-GMS tags.
torch_impl = _TorchMemorySaverImpl(hook_mode="torch")
# Read lock mode set by setup_gms() (defaults to RW_OR_RO)
......@@ -74,7 +74,6 @@ def patch_torch_memory_saver() -> None:
gms_impl = GMSMemorySaverImpl(
torch_impl=torch_impl,
socket_path=socket_path,
device_index=device_index,
mode=_gms_lock_mode,
)
......@@ -82,9 +81,8 @@ def patch_torch_memory_saver() -> None:
# Set _impl directly (accessible via gms_impl property)
self._impl = gms_impl
logger.info(
"[GMS] Using GMS mode (device=%d, socket=%s, mode=%s)",
"[GMS] Using GMS mode (device=%d, mode=%s)",
device_index,
socket_path,
gms_impl.get_mode(),
)
del self._impl_ctor_kwargs
......@@ -95,6 +93,23 @@ def patch_torch_memory_saver() -> None:
entrypoint_module.TorchMemorySaver._ensure_initialized = patched_ensure_initialized
@contextmanager
def patched_configure_subprocess():
"""Avoid LD_PRELOAD in GMS mode; keep upstream behavior otherwise."""
singleton = torch_memory_saver.torch_memory_saver
ctor_kwargs = getattr(singleton, "_impl_ctor_kwargs", None) or {}
hook_mode = ctor_kwargs.get("hook_mode")
if hook_mode is None or hook_mode == "gms":
logger.info("[GMS] torch_memory_saver.configure_subprocess is a no-op")
yield
return
with original_configure_subprocess():
yield
torch_memory_saver.configure_subprocess = patched_configure_subprocess
# Add property to access GMS impl directly from the singleton
from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl
......@@ -132,9 +147,10 @@ def patch_torch_memory_saver() -> None:
def patch_model_runner() -> None:
"""Patch SGLang's ModelRunner to fix memory accounting with pre-loaded weights.
When weights are pre-loaded via GMS (import-only mode), SGLang's min_per_gpu_memory
captured before loading is lower than device total. This causes under-reservation
of overhead memory in KV cache calculation.
SGLang 0.5.9 passes a startup free-memory snapshot as total_gpu_memory into
init_memory_pool(). In GMS read mode, imported weights can already occupy GPU
memory, so that snapshot is lower than physical device capacity and the KV cache
overhead term is under-reserved.
"""
global _model_runner_patched
......@@ -151,25 +167,56 @@ def patch_model_runner() -> None:
return
original_init_memory_pool = ModelRunner.init_memory_pool
memory_arg_name = next(
(
name
for name in inspect.signature(original_init_memory_pool).parameters
if name != "self"
),
None,
)
def patched_init_memory_pool(self, *args, **kwargs):
"""Patched init_memory_pool that uses device total for overhead calculation."""
"""Patch init_memory_pool for SGLang versions that use total_gpu_memory.
SGLang's KV cache formula uses total_gpu_memory as the baseline:
rest_memory = available - total*(1-mem_fraction).
Replace that baseline with physical device capacity when GMS imported
weights are already resident. Newer SGLang versions changed this API, so
only rewrite the old total_gpu_memory parameter shape.
"""
from gpu_memory_service.integrations.sglang.memory_saver import (
get_gms_memory_saver_impl,
)
impl = get_gms_memory_saver_impl()
if impl is not None and impl.get_imported_weights_bytes() > 0:
total_memory = torch.cuda.get_device_properties(
total_memory_gib = torch.cuda.get_device_properties(
torch.cuda.current_device()
).total_memory
if hasattr(self, "min_per_gpu_memory"):
old_value = self.min_per_gpu_memory
self.min_per_gpu_memory = total_memory
).total_memory / (1 << 30)
if memory_arg_name == "total_gpu_memory":
if args:
old_value = args[0]
args = (total_memory_gib,) + args[1:]
elif memory_arg_name in kwargs:
old_value = kwargs[memory_arg_name]
kwargs = dict(kwargs)
kwargs[memory_arg_name] = total_memory_gib
else:
old_value = None
logger.info(
"[GMS] Adjusted total_gpu_memory: %s -> %.2f GiB",
(
f"{old_value:.2f} GiB"
if isinstance(old_value, (int, float))
else "<missing>"
),
total_memory_gib,
)
elif memory_arg_name is not None:
logger.info(
"[GMS] Adjusted min_per_gpu_memory: %.2f GiB -> %.2f GiB",
old_value / (1 << 30),
total_memory / (1 << 30),
"[GMS] Leaving %s unchanged in patched init_memory_pool",
memory_arg_name,
)
return original_init_memory_pool(self, *args, **kwargs)
......
......@@ -11,11 +11,11 @@ processes import from GMS metadata (RO).
from __future__ import annotations
import logging
from dataclasses import replace
from typing import TYPE_CHECKING
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.utils import get_socket_path
......@@ -23,6 +23,7 @@ from gpu_memory_service.integrations.common.utils import (
finalize_gms_write,
get_gms_lock_mode,
setup_meta_tensor_workaround,
strip_gms_model_loader_config,
)
if TYPE_CHECKING:
......@@ -49,23 +50,14 @@ def register_gms_loader(load_format: str = "gms") -> None:
class GMSModelLoader(BaseModelLoader):
"""vLLM model loader that loads weights via GPU Memory Service."""
# Keys in model_loader_extra_config that are GMS-specific and should
# not be passed to the fallback DefaultModelLoader.
_GMS_EXTRA_KEYS = frozenset({"gms_read_only"})
def __init__(self, load_config):
super().__init__(load_config)
# Strip GMS-specific keys before creating the fallback loader,
# otherwise DefaultModelLoader rejects unknown extra config.
extra = getattr(load_config, "model_loader_extra_config", None) or {}
clean_extra = {
k: v for k, v in extra.items() if k not in self._GMS_EXTRA_KEYS
}
self.default_loader = DefaultModelLoader(
replace(
strip_gms_model_loader_config(
load_config,
load_format="auto",
model_loader_extra_config=clean_extra,
)
)
......@@ -79,8 +71,8 @@ def register_gms_loader(load_format: str = "gms") -> None:
device = torch.cuda.current_device()
extra = getattr(self.load_config, "model_loader_extra_config", {}) or {}
mode = get_gms_lock_mode(extra)
gms_client, pool = get_or_create_gms_client_memory_manager(
get_socket_path(device),
gms_client = get_or_create_gms_client_memory_manager(
get_socket_path(device, "weights"),
device,
mode=mode,
tag="weights",
......@@ -91,7 +83,6 @@ def register_gms_loader(load_format: str = "gms") -> None:
else:
return _load_write_mode(
gms_client,
pool,
vllm_config,
model_config,
self.default_loader,
......@@ -130,7 +121,6 @@ def _load_read_mode(
def _load_write_mode(
gms_client: "GMSClientMemoryManager",
pool,
vllm_config,
model_config,
default_loader,
......@@ -143,18 +133,15 @@ def _load_write_mode(
"""
global _last_imported_weights_bytes
from torch.cuda.memory import use_mem_pool
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
gms_client.clear_all_handles()
# Allocate model tensors using GMS memory pool
with set_default_torch_dtype(model_config.dtype):
with use_mem_pool(pool, device=target_device):
with gms_use_mem_pool("weights", target_device):
with target_device:
model = initialize_model(
vllm_config=vllm_config, model_config=model_config
......
......@@ -11,9 +11,7 @@ They should only allocate on their cache when they are the active/leader engine.
from __future__ import annotations
import logging
import time
import torch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = logging.getLogger(__name__)
......@@ -48,6 +46,10 @@ class GMSShadowModelRunner(GPUModelRunner):
logger.info(
"[Shadow] Init phase: stored config, skipping KV cache allocation"
)
print(
"[Shadow] Init phase: stored config, skipping KV cache allocation",
flush=True,
)
return {}
return super().initialize_kv_cache_tensors(kv_cache_config, kernel_block_sizes)
......@@ -86,7 +88,9 @@ class GMSShadowModelRunner(GPUModelRunner):
"""Allocate KV cache on wake using config stored during shadow init.
Called by GMSWorker.wake_up() after shadow init phase is exited.
Waits up to 60s for GPU memory to be freed.
GMS kv_cache RW lock acquisition serves as the memory barrier — the
dying engine's abort() releases the lock and frees memory before we
can connect.
"""
assert hasattr(
self, "_shadow_kv_cache_config"
......@@ -95,47 +99,7 @@ class GMSShadowModelRunner(GPUModelRunner):
self, "_shadow_kernel_block_sizes"
), "_shadow_kernel_block_sizes not set — was enter_shadow_init() called?"
# OOM remediation during failover: wait for the dying engine to release memory.
# TODO: This will be replaced with a barrier in GMS when we manage kv cache there instead
config = self._shadow_kv_cache_config
kv_cache_bytes = sum(t.size for t in config.kv_cache_tensors)
free_bytes, _ = torch.cuda.mem_get_info()
if free_bytes < kv_cache_bytes:
logger.info(
"[Shadow] Waiting for GPU memory (need %.2f GiB, free %.2f GiB)",
kv_cache_bytes / (1 << 30),
free_bytes / (1 << 30),
)
deadline = time.monotonic() + 60.0
last_log = time.monotonic()
while free_bytes < kv_cache_bytes:
if time.monotonic() > deadline:
raise RuntimeError(
f"Timed out waiting for GPU memory: "
f"need {kv_cache_bytes / (1 << 30):.2f} GiB, "
f"free {free_bytes / (1 << 30):.2f} GiB"
)
now = time.monotonic()
if now - last_log >= 5.0:
elapsed = now - (deadline - 60.0)
remaining = deadline - now
logger.info(
"[Shadow] Still waiting for GPU memory: "
"need %.2f GiB, free %.2f GiB "
"(%.0fs elapsed, %.0fs remaining)",
kv_cache_bytes / (1 << 30),
free_bytes / (1 << 30),
elapsed,
remaining,
)
last_log = now
time.sleep(0.5)
free_bytes = torch.cuda.mem_get_info()[0]
logger.info(
"[Shadow] GPU memory available (free %.2f GiB), proceeding",
free_bytes / (1 << 30),
)
logger.info("[Shadow] Allocating KV cache on wake")
......@@ -163,10 +127,11 @@ class GMSShadowModelRunner(GPUModelRunner):
logger.debug("[Shadow] KV transfer group not available")
total_bytes = sum(t.numel() * t.element_size() for t in kv_caches.values())
logger.info(
"[Shadow] Allocated KV cache on wake: %.2f GiB (%d tensors)",
msg = "[Shadow] Allocated KV cache on wake: %.2f GiB (%d tensors)" % (
total_bytes / (1 << 30),
len(kv_caches),
)
logger.info(msg)
print(msg, flush=True)
return kv_caches
......@@ -48,12 +48,12 @@ def patch_memory_snapshot() -> None:
def patched_measure(self):
original_measure(self)
manager = get_gms_client_memory_manager()
manager = get_gms_client_memory_manager("weights")
assert manager is not None, "GMS client is not initialized"
if manager.granted_lock_type == GrantedLockType.RO:
allocations = manager.list_handles()
committed_bytes = sum(alloc.get("aligned_size", 0) for alloc in allocations)
committed_bytes = sum(alloc.aligned_size for alloc in allocations)
else:
# NOTE: by design, we want to assume we have the whole GPU when writing
# weights for the first time, so we don't make an adjustment.
......
......@@ -23,6 +23,7 @@ from gpu_memory_service import (
get_or_create_gms_client_memory_manager,
)
from gpu_memory_service.client.memory_manager import StaleMemoryLayoutError
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common import patch_empty_cache
......@@ -70,18 +71,16 @@ class GMSWorker(Worker):
# Establish weights GMS connection (so MemorySnapshot can query committed bytes).
# Lock type is determined by model_loader_extra_config, set upstream by
# configure_gms_lock_mode() in main.py.
socket_path = get_socket_path(device)
extra = (
getattr(self.vllm_config.load_config, "model_loader_extra_config", {}) or {}
)
mode = get_gms_lock_mode(extra)
get_or_create_gms_client_memory_manager(
socket_path,
get_socket_path(device, "weights"),
device,
mode=mode,
tag="weights",
)
# Parent will set device again (harmless) and do memory checks
super().init_device()
......@@ -111,9 +110,9 @@ class GMSWorker(Worker):
torch.cuda.synchronize()
torch_peak = torch.cuda.max_memory_allocated()
# If weights are strictly loaded (RO), torch's memory accounting will miss them since we didn't go through the mempool
# We therefore add in the memory of the weights into our accounting here
# This is not an issue on engines that write the weights and then downgrade to RO
# GMS weights mapped via cuMemMap are invisible to PyTorch's memory
# stats on RO engines. Add them explicitly. On RW engines, torch_peak
# already includes weights so skip to avoid double-counting.
weights_memory = int(getattr(self.model_runner, "model_memory_usage", 0))
if torch_peak < weights_memory:
non_kv_cache_memory = torch_peak + weights_memory
......@@ -122,19 +121,62 @@ class GMSWorker(Worker):
projected_available = self.requested_memory - non_kv_cache_memory
logger.info(
msg = (
"[GMS] Shadow mode: projected available memory "
"%.2f GiB (requested=%.2f GiB, non_kv=%.2f GiB, "
"torch_peak=%.2f GiB, weights=%.2f GiB)",
"torch_peak=%.2f GiB, weights=%.2f GiB)"
% (
projected_available / (1 << 30),
self.requested_memory / (1 << 30),
non_kv_cache_memory / (1 << 30),
torch_peak / (1 << 30),
weights_memory / (1 << 30),
)
)
logger.info(msg)
print(msg, flush=True)
return int(projected_available)
def initialize_from_config(self, kv_cache_config) -> None:
"""Allocate KV cache with a dedicated RW-only GMS tag.
Also validates cudagraph mode for shadow mode compatibility.
"""
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
if is_shadow_mode():
# GMS client for kv cache is deferred to wake for shadow mode
# GMSShadowModelRunner.initialize_kv_cache intercepts and stores config without creating an allocation
self.model_runner.initialize_kv_cache(kv_cache_config)
elif self.vllm_config.model_config.enable_sleep_mode:
# Normal sleep/wake: create kv_cache GMS tag now for unmap/remap
device = self.local_rank
get_or_create_gms_client_memory_manager(
get_socket_path(device, "kv_cache"),
device,
mode=RequestedLockType.RW,
tag="kv_cache",
)
with gms_use_mem_pool("kv_cache", torch.device(f"cuda:{device}")):
self.model_runner.initialize_kv_cache(kv_cache_config)
else:
# No sleep mode: plain KV cache init
self.model_runner.initialize_kv_cache(kv_cache_config)
# Validate cudagraph mode for shadow mode compatibility
if is_shadow_mode():
from vllm.config import CUDAGraphMode
mode = self.model_runner.compilation_config.cudagraph_mode
if mode not in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE):
raise RuntimeError(
f"Shadow mode requires PIECEWISE cudagraph mode after resolution, "
f"but got {mode.name}. vLLM's config resolution overrode it."
)
def load_model(self, *args, **kwargs) -> None:
"""Load model with corrected memory accounting.
......@@ -166,54 +208,38 @@ class GMSWorker(Worker):
except Exception as e:
logger.debug("[GMS] Could not correct memory accounting: %s", e)
def initialize_from_config(self, kv_cache_config) -> None:
"""Initialize from config with post-init cudagraph mode assertion.
vLLM can try to upgrade the cudagraph mode in certain scenarios. We
assert that the final resolved mode is still compatible with shadow mode.
"""
super().initialize_from_config(kv_cache_config)
if is_shadow_mode():
from vllm.config import CUDAGraphMode
mode = self.model_runner.compilation_config.cudagraph_mode
if mode not in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE):
raise RuntimeError(
f"Shadow mode requires PIECEWISE cudagraph mode after resolution, "
f"but got {mode.name}. vLLM's config resolution overrode it."
)
def sleep(self, level: int = 1) -> None:
"""
vLLM sleep implementation with GMS integration.
NOTE: `level` is a no-op here: weights are only unmapped (but remain in GPU memory).
NOTE: We do NOT call super().sleep() because it tries to copy GPU buffers to CPU,
which segfaults on already-unmapped GMS memory.
Handles two cases for KV cache:
1. Normal: KV cache was allocated, sleep via CuMemAllocator
2. Shadow: KV cache was skipped at startup, nothing to do
1. Normal: KV cache was allocated via GMS, unmap + abort
2. Shadow: KV cache was skipped at startup, manager has no allocations
(unmap_all_vas is a no-op, abort disconnects)
"""
free_bytes_before = torch.cuda.mem_get_info()[0]
# Unmap GMS weights: synchronize + unmap all VAs + disconnect
manager = get_gms_client_memory_manager()
assert manager is not None, "GMS client is not initialized"
assert not manager.is_unmapped, "GMS weights are already unmapped"
manager.unmap_all_vas()
manager.disconnect()
# Sleep KV cache via CuMemAllocator (discard, no CPU backup)
# If KV cache was never allocated (shadow engine mode), this is a no-op
from vllm.device_allocator.cumem import CuMemAllocator
kv_caches = getattr(self.model_runner, "kv_caches", None)
if kv_caches:
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=tuple())
weights_manager = get_gms_client_memory_manager("weights")
assert weights_manager is not None, "GMS weights client is not initialized"
assert not weights_manager.is_unmapped, "GMS weights are already unmapped"
weights_manager.unmap_all_vas()
weights_manager.abort()
# Unmap GMS KV cache: unmap all VAs + disconnect
# In shadow mode, kv_cache manager is deferred to wake — nothing to unmap.
kv_cache_manager = get_gms_client_memory_manager("kv_cache")
if kv_cache_manager is not None:
assert not kv_cache_manager.is_unmapped, "GMS KV cache is already unmapped"
kv_cache_manager.unmap_all_vas()
kv_cache_manager.abort()
else:
logger.info("[GMS] KV cache not allocated (shadow mode), skipping sleep")
logger.info(
"[GMS] No kv_cache manager (shadow mode), skipping kv_cache sleep"
)
free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before
......@@ -228,7 +254,7 @@ class GMSWorker(Worker):
"""vLLM wake implementation with GMS integration.
Handles two cases for KV cache:
1. Normal: KV cache was allocated at startup, reallocate via CuMemAllocator
1. Normal: KV cache was allocated at startup, reconnect + reallocate + remap
2. Shadow: KV cache was skipped at startup, allocate via allocate_kv_cache_on_wake()
"""
if (
......@@ -241,16 +267,16 @@ class GMSWorker(Worker):
tags = ["weights", "kv_cache"]
if "weights" in tags:
manager = get_gms_client_memory_manager()
assert manager is not None, "GMS client is not initialized"
assert manager.is_unmapped, "GMS weights are not unmapped"
weights_manager = get_gms_client_memory_manager("weights")
assert weights_manager is not None, "GMS weights client is not initialized"
assert weights_manager.is_unmapped, "GMS weights are not unmapped"
# These errors are fatal and unrecoverable in a worker subprocess:
# the worker cannot serve requests without weights. sys.exit(1)
# ensures clean termination so the orchestrator (K8s) can restart.
try:
manager.connect(RequestedLockType.RO, timeout_ms=30_000)
manager.remap_all_vas()
weights_manager.connect(RequestedLockType.RO, timeout_ms=30_000)
weights_manager.remap_all_vas()
except TimeoutError:
logger.error(
"Fatal: timed out waiting for GMS RO lock during remap "
......@@ -270,15 +296,30 @@ class GMSWorker(Worker):
# Check if KV cache was skipped at startup (shadow engine mode)
kv_caches = getattr(self.model_runner, "kv_caches", None)
if not kv_caches:
# Shadow mode: create kv_cache manager now (deferred from init
# to avoid RW lock contention between concurrent engines).
logger.info("[GMS] KV cache not allocated - allocating on wake")
get_or_create_gms_client_memory_manager(
get_socket_path(self.local_rank, "kv_cache"),
self.local_rank,
mode=RequestedLockType.RW,
tag="kv_cache",
)
with gms_use_mem_pool(
"kv_cache", torch.device("cuda", self.local_rank)
):
self.model_runner.allocate_kv_cache_on_wake()
logger.info("[GMS] Successfully allocated KV cache on wake")
else:
# Normal case: KV cache was allocated, reallocate via CuMemAllocator
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=["kv_cache"])
# Normal case: KV cache was allocated via GMS, reconnect + reallocate + remap
kv_cache_manager = get_gms_client_memory_manager("kv_cache")
assert (
kv_cache_manager is not None
), "GMS KV cache client is not initialized"
assert kv_cache_manager.is_unmapped, "GMS KV cache is not unmapped"
kv_cache_manager.connect(RequestedLockType.RW)
kv_cache_manager.reallocate_all_handles(tag="kv_cache")
kv_cache_manager.remap_all_vas()
# Reinitialize FP8 KV scales if needed
if self.cache_config.cache_dtype.startswith("fp8") and hasattr(
......@@ -287,12 +328,18 @@ class GMSWorker(Worker):
self.model_runner.init_fp8_kv_scales()
def _maybe_get_memory_pool_context(self, tag: str):
"""Skip CuMemAllocator for weights when using GMS.
GMS manages its own memory pool for weights, so we don't want vLLM's
CuMemAllocator to interfere.
"""Route tag-scoped runtime allocations to the right allocator.
Weight tensors are allocated explicitly in the GMS model-loader path,
not through vLLM's tagged runtime allocator hook. For `weights` we
therefore only suppress CuMemAllocator here so it does not interfere
with the loader-managed GMS allocations. `kv_cache` is the tag that
actually allocates through this hook, so it uses the dedicated GMS
mempool.
"""
if tag == "weights":
logger.debug("[GMS] Skipping CuMemAllocator for weights")
return nullcontext()
if tag == "kv_cache":
return gms_use_mem_pool("kv_cache", torch.device("cuda", self.local_rank))
return super()._maybe_get_memory_pool_context(tag)
......@@ -9,26 +9,33 @@ from gpu_memory_service.common.types import (
ServerState,
StateSnapshot,
)
from gpu_memory_service.server.handler import MetadataEntry, RequestHandler
from gpu_memory_service.server.locking import Connection, GMSLocalFSM
from gpu_memory_service.server.memory_manager import (
from gpu_memory_service.server.allocations import (
AllocationInfo,
AllocationNotFoundError,
GMSServerMemoryManager,
GMSAllocationManager,
)
from gpu_memory_service.server.gms import GMS, MetadataEntry
from gpu_memory_service.server.rpc import GMSRPCServer
from gpu_memory_service.server.session import (
Connection,
GMSSessionManager,
InvalidTransition,
OperationNotAllowed,
)
__all__ = [
"GMSRPCServer",
"GMSServerMemoryManager",
"GMS",
"GMSSessionManager",
"GMSAllocationManager",
"AllocationInfo",
"AllocationNotFoundError",
"MetadataEntry",
"Connection",
"GrantedLockType",
"RequestedLockType",
"RequestHandler",
"ServerState",
"GMSLocalFSM",
"StateSnapshot",
"InvalidTransition",
"OperationNotAllowed",
]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Server-side CUDA allocation store."""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Callable, Optional
from uuid import uuid4
from gpu_memory_service.common.cuda_utils import (
align_to_granularity,
cuda_ensure_initialized,
cumem_create_tolerate_oom,
cumem_export_to_shareable_handle,
cumem_get_allocation_granularity,
cumem_release,
)
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class AllocationInfo:
allocation_id: str
size: int
aligned_size: int
handle: int
tag: str
layout_slot: int
created_at: float
class AllocationNotFoundError(Exception):
"""Raised when an allocation_id doesn't exist."""
class GMSAllocationManager:
"""Server-side CUDA VMM allocation store."""
def __init__(
self,
device: int = 0,
*,
allocation_retry_interval: float = 0.5,
allocation_retry_timeout: Optional[float] = None,
):
if allocation_retry_interval <= 0:
raise ValueError(
f"allocation_retry_interval must be > 0, got {allocation_retry_interval}"
)
if allocation_retry_timeout is not None and allocation_retry_timeout <= 0:
raise ValueError(
f"allocation_retry_timeout must be > 0 when set, got {allocation_retry_timeout}"
)
self._device = device
self._allocations: dict[str, AllocationInfo] = {}
self._next_layout_slot = 0
cuda_ensure_initialized()
self._granularity = cumem_get_allocation_granularity(device)
self._allocation_retry_interval = allocation_retry_interval
self._allocation_retry_timeout = allocation_retry_timeout
logger.info(
"GMSAllocationManager initialized: device=%d, granularity=%d, "
"alloc_retry_interval=%.3f, alloc_retry_timeout=%s",
device,
self._granularity,
self._allocation_retry_interval,
(
f"{self._allocation_retry_timeout:.3f}"
if self._allocation_retry_timeout is not None
else "none"
),
)
@property
def device(self) -> int:
return self._device
@property
def allocation_count(self) -> int:
return len(self._allocations)
async def allocate(
self,
size: int,
tag: str = "default",
is_connected: Optional[Callable[[], bool]] = None,
on_oom: Optional[Callable[[], None]] = None,
) -> AllocationInfo:
if size <= 0:
raise ValueError(f"size must be > 0, got {size}")
aligned_size = align_to_granularity(size, self._granularity)
started_at = time.monotonic()
reported_oom = False
while True:
if is_connected is not None and not is_connected():
raise ConnectionAbortedError(
"RW client disconnected during allocation retry"
)
allocated, handle = cumem_create_tolerate_oom(aligned_size, self._device)
if allocated:
break
if on_oom is not None and not reported_oom:
on_oom()
reported_oom = True
if self._allocation_retry_timeout is not None:
waited = time.monotonic() - started_at
if waited >= self._allocation_retry_timeout:
raise TimeoutError(
"Timed out waiting for GPU memory: "
f"requested_size={size}, aligned_size={aligned_size}, "
f"tag={tag}, waited_sec={waited:.3f}"
)
logger.warning(
"cuMemCreate OOM for aligned_size=%d bytes, tag=%s; retrying in %.3fs",
aligned_size,
tag,
self._allocation_retry_interval,
)
await asyncio.sleep(self._allocation_retry_interval)
info = AllocationInfo(
allocation_id=str(uuid4()),
size=size,
aligned_size=aligned_size,
handle=int(handle),
tag=tag,
layout_slot=self._next_layout_slot,
created_at=time.time(),
)
self._next_layout_slot = info.layout_slot + 1
self._allocations[info.allocation_id] = info
logger.debug(
"Allocated %s: size=%d, aligned=%d, tag=%s, slot=%d",
info.allocation_id,
size,
aligned_size,
tag,
info.layout_slot,
)
return info
def export_allocation(self, allocation_id: str) -> int:
return cumem_export_to_shareable_handle(
self.get_allocation(allocation_id).handle
)
def free_allocation(self, allocation_id: str) -> bool:
info = self._allocations.get(allocation_id)
if info is None:
return False
cumem_release(info.handle)
self._allocations.pop(allocation_id, None)
logger.debug("Freed allocation: %s", allocation_id)
return True
def clear_all(self) -> int:
allocation_ids = list(self._allocations)
for allocation_id in allocation_ids:
info = self._allocations[allocation_id]
cumem_release(info.handle)
self._allocations.pop(allocation_id, None)
if allocation_ids:
logger.info("Cleared %d allocations", len(allocation_ids))
self._next_layout_slot = 0
return len(allocation_ids)
def get_allocation(self, allocation_id: str) -> AllocationInfo:
info = self._allocations.get(allocation_id)
if info is None:
raise AllocationNotFoundError(f"Unknown allocation: {allocation_id}")
return info
def list_allocations(self, tag: Optional[str] = None) -> list[AllocationInfo]:
allocations = list(self._allocations.values())
allocations.sort(key=lambda info: info.layout_slot)
if tag is None:
return allocations
return [info for info in allocations if info.tag == tag]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Top-level server-side GMS service."""
from __future__ import annotations
import hashlib
import logging
from collections import deque
from dataclasses import dataclass
from typing import Callable, Optional
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
CommitRequest,
CommitResponse,
ExportAllocationRequest,
ExportAllocationResponse,
FreeAllocationRequest,
FreeAllocationResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetEventHistoryResponse,
GetLockStateRequest,
GetLockStateResponse,
GetRuntimeStateResponse,
GetStateHashRequest,
GetStateHashResponse,
GMSRuntimeEvent,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from .allocations import AllocationInfo, GMSAllocationManager
from .session import Connection, GMSSessionManager
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class MetadataEntry:
allocation_id: str
offset_bytes: int
value: bytes
class GMS:
"""Owns all non-transport server state."""
_MAX_EVENTS = 256
def __init__(
self,
device: int = 0,
*,
allocation_retry_interval: float = 0.5,
allocation_retry_timeout: Optional[float] = None,
):
self._allocations = GMSAllocationManager(
device,
allocation_retry_interval=allocation_retry_interval,
allocation_retry_timeout=allocation_retry_timeout,
)
self._sessions = GMSSessionManager()
self._events: deque[GMSRuntimeEvent] = deque(maxlen=self._MAX_EVENTS)
self._metadata: dict[str, MetadataEntry] = {}
self._memory_layout_hash = ""
logger.info("GMS initialized: device=%d", device)
@property
def state(self) -> ServerState:
return self._sessions.state
@property
def committed(self) -> bool:
return self._sessions.snapshot().committed
@property
def allocation_count(self) -> int:
return self._allocations.allocation_count
def is_ready(self) -> bool:
return self._sessions.snapshot().is_ready
def get_runtime_state(self) -> GetRuntimeStateResponse:
session = self._sessions.snapshot()
return GetRuntimeStateResponse(
state=session.state.name,
has_rw_session=session.has_rw_session,
ro_session_count=session.ro_session_count,
waiting_writers=session.waiting_writers,
committed=session.committed,
is_ready=session.is_ready,
allocation_count=self._allocations.allocation_count,
memory_layout_hash=self._memory_layout_hash,
)
def get_event_history(self) -> GetEventHistoryResponse:
return GetEventHistoryResponse(events=list(self._events))
def next_session_id(self) -> str:
return self._sessions.next_session_id()
async def acquire_lock(
self,
mode: RequestedLockType,
timeout_ms: int | None,
session_id: str,
) -> GrantedLockType | None:
return await self._sessions.acquire_lock(mode, timeout_ms, session_id)
async def cancel_connect(
self,
session_id: str,
mode: GrantedLockType | None,
) -> None:
await self._sessions.cancel_connect(session_id, mode)
def _validate_metadata_target(
self,
allocation: AllocationInfo,
offset_bytes: int,
) -> None:
if offset_bytes < 0:
raise ValueError(f"offset_bytes must be >= 0, got {offset_bytes}")
if offset_bytes >= allocation.aligned_size:
raise ValueError(
f"offset_bytes {offset_bytes} out of range for allocation {allocation.allocation_id} "
f"(aligned_size={allocation.aligned_size})"
)
def _drop_metadata_for_allocation(self, allocation_id: str) -> int:
keys_to_remove = [
key
for key, entry in self._metadata.items()
if entry.allocation_id == allocation_id
]
for key in keys_to_remove:
self._metadata.pop(key, None)
return len(keys_to_remove)
def _validate_metadata_integrity(
self,
allocations_by_id: dict[str, AllocationInfo],
) -> None:
for key, entry in self._metadata.items():
info = allocations_by_id.get(entry.allocation_id)
if info is None:
raise AssertionError(
f"Metadata key {key!r} references missing allocation "
f"{entry.allocation_id!r}"
)
if entry.offset_bytes < 0 or entry.offset_bytes >= info.aligned_size:
raise AssertionError(
f"Metadata key {key!r} has invalid offset {entry.offset_bytes} "
f"for allocation {entry.allocation_id!r} "
f"(aligned_size={info.aligned_size})"
)
def _compute_memory_layout_hash(self, allocations: list[AllocationInfo]) -> str:
h = hashlib.sha256()
allocation_slots_by_id: dict[str, int] = {}
for info in sorted(allocations, key=lambda info: info.layout_slot):
allocation_slots_by_id[info.allocation_id] = info.layout_slot
h.update(
f"{info.layout_slot}:{info.size}:{info.aligned_size}:{info.tag}".encode()
)
for key in sorted(self._metadata):
entry = self._metadata[key]
layout_slot = allocation_slots_by_id[entry.allocation_id]
h.update(f"{key}:{layout_slot}:{entry.offset_bytes}:".encode())
h.update(entry.value)
return h.hexdigest()
def _clear_layout_state(self) -> int:
self._metadata.clear()
self._memory_layout_hash = ""
return self._allocations.clear_all()
def on_connect(self, conn: Connection) -> None:
if conn.mode == GrantedLockType.RW:
had_committed_layout = self._sessions.snapshot().committed
cleared = self._clear_layout_state()
if had_committed_layout:
self._events.append(
GMSRuntimeEvent(
kind="allocations_cleared",
allocation_count=cleared,
)
)
self._sessions.on_connect(conn)
if conn.mode == GrantedLockType.RW:
self._events.append(GMSRuntimeEvent(kind="rw_connected"))
async def cleanup_connection(self, conn: Connection | None) -> None:
event = self._sessions.begin_cleanup(conn)
if event == StateEvent.RW_ABORT:
logger.warning("RW aborted; clearing active layout")
cleared = self._clear_layout_state()
self._events.append(GMSRuntimeEvent(kind="rw_aborted"))
self._events.append(
GMSRuntimeEvent(
kind="allocations_cleared",
allocation_count=cleared,
)
)
await self._sessions.finish_cleanup(conn)
async def handle_request(
self,
conn: Connection,
msg,
is_connected: Callable[[], bool],
) -> tuple[object, int, bool]:
msg_type = type(msg)
self._sessions.check_operation(msg_type, conn)
if msg_type is CommitRequest:
if self.state != ServerState.RW:
raise AssertionError("RW state is not active")
allocations = self._allocations.list_allocations()
allocations_by_id = {info.allocation_id: info for info in allocations}
self._validate_metadata_integrity(allocations_by_id)
self._memory_layout_hash = self._compute_memory_layout_hash(allocations)
logger.info(
"Committed layout with state hash: %s...",
self._memory_layout_hash[:16],
)
self._sessions.on_commit(conn)
self._events.append(GMSRuntimeEvent(kind="committed"))
return CommitResponse(success=True), -1, True
if msg_type is AllocateRequest:
if self.state != ServerState.RW:
raise AssertionError("RW state is not active")
info = await self._allocations.allocate(
size=msg.size,
tag=msg.tag,
is_connected=is_connected,
on_oom=lambda: self._events.append(
GMSRuntimeEvent(
kind="allocation_oom",
allocation_count=self._allocations.allocation_count,
)
),
)
return (
AllocateResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
layout_slot=info.layout_slot,
),
-1,
False,
)
if msg_type is GetLockStateRequest:
snapshot = self._sessions.snapshot()
return (
GetLockStateResponse(
state=snapshot.state.name,
has_rw_session=snapshot.has_rw_session,
ro_session_count=snapshot.ro_session_count,
waiting_writers=snapshot.waiting_writers,
committed=snapshot.committed,
is_ready=snapshot.is_ready,
),
-1,
False,
)
if msg_type is GetAllocationStateRequest:
return (
GetAllocationStateResponse(
allocation_count=self._allocations.allocation_count
),
-1,
False,
)
if msg_type is ExportAllocationRequest:
info = self._allocations.get_allocation(msg.allocation_id)
fd = self._allocations.export_allocation(info.allocation_id)
return (
ExportAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
layout_slot=info.layout_slot,
),
fd,
False,
)
if msg_type is GetStateHashRequest:
return (
GetStateHashResponse(memory_layout_hash=self._memory_layout_hash),
-1,
False,
)
if msg_type is GetAllocationRequest:
info = self._allocations.get_allocation(msg.allocation_id)
return (
GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
layout_slot=info.layout_slot,
),
-1,
False,
)
if msg_type is ListAllocationsRequest:
return (
ListAllocationsResponse(
allocations=[
GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
layout_slot=info.layout_slot,
)
for info in self._allocations.list_allocations(msg.tag)
]
),
-1,
False,
)
if msg_type is FreeAllocationRequest:
success = self._allocations.free_allocation(msg.allocation_id)
if success:
self._drop_metadata_for_allocation(msg.allocation_id)
return (
FreeAllocationResponse(success=success),
-1,
False,
)
if msg_type is MetadataPutRequest:
allocation = self._allocations.get_allocation(msg.allocation_id)
self._validate_metadata_target(allocation, msg.offset_bytes)
self._metadata[msg.key] = MetadataEntry(
allocation_id=allocation.allocation_id,
offset_bytes=msg.offset_bytes,
value=msg.value,
)
return MetadataPutResponse(success=True), -1, False
if msg_type is MetadataGetRequest:
entry = self._metadata.get(msg.key)
if entry is None:
return MetadataGetResponse(found=False), -1, False
return (
MetadataGetResponse(
found=True,
allocation_id=entry.allocation_id,
offset_bytes=entry.offset_bytes,
value=entry.value,
),
-1,
False,
)
if msg_type is MetadataDeleteRequest:
return (
MetadataDeleteResponse(
deleted=self._metadata.pop(msg.key, None) is not None
),
-1,
False,
)
if msg_type is MetadataListRequest:
if not msg.prefix:
keys = sorted(self._metadata)
else:
keys = sorted(
key for key in self._metadata if key.startswith(msg.prefix)
)
return MetadataListResponse(keys=keys), -1, False
raise ValueError(f"Unknown request: {msg_type.__name__}")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Request handlers for GPU Memory Service."""
import hashlib
import logging
from dataclasses import dataclass
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
ClearAllResponse,
FreeRequest,
FreeResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateResponse,
GetLockStateResponse,
GetStateHashResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import derive_state
from .memory_manager import AllocationNotFoundError, GMSServerMemoryManager
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class MetadataEntry:
allocation_id: str
offset_bytes: int
value: bytes
class RequestHandler:
"""Handles allocation and metadata requests."""
def __init__(self, device: int = 0):
self._memory_manager = GMSServerMemoryManager(device)
self._metadata: dict[str, MetadataEntry] = {}
self._memory_layout_hash: str = (
"" # Hash of allocations + metadata, computed on commit
)
logger.info(f"RequestHandler initialized: device={device}")
@property
def granularity(self) -> int:
return self._memory_manager.granularity
def on_rw_abort(self) -> None:
"""Called when RW connection closes without commit."""
logger.warning("RW aborted; clearing allocations and metadata")
self._memory_manager.clear_all()
self._metadata.clear()
self._memory_layout_hash = ""
def on_commit(self) -> None:
"""Called when RW connection commits. Computes state hash."""
self._memory_layout_hash = self._compute_memory_layout_hash()
logger.info(f"Committed with state hash: {self._memory_layout_hash[:16]}...")
def _compute_memory_layout_hash(self) -> str:
"""Compute hash of current allocations + metadata."""
h = hashlib.sha256()
# Hash allocations (sorted by ID for determinism)
for info in sorted(
self._memory_manager.list_allocations(), key=lambda x: x.allocation_id
):
h.update(
f"{info.allocation_id}:{info.size}:{info.aligned_size}:{info.tag}".encode()
)
# Hash metadata (sorted by key for determinism)
for key in sorted(self._metadata.keys()):
entry = self._metadata[key]
h.update(f"{key}:{entry.allocation_id}:{entry.offset_bytes}:".encode())
h.update(entry.value)
return h.hexdigest()
def on_shutdown(self) -> None:
"""Called on server shutdown."""
if self._memory_manager.allocation_count > 0:
count = self._memory_manager.clear_all()
self._metadata.clear()
logger.info(f"Released {count} GPU allocations during shutdown")
# ==================== State Queries ====================
def handle_get_lock_state(
self,
has_rw: bool,
ro_count: int,
waiting_writers: int,
committed: bool,
) -> GetLockStateResponse:
"""Get lock/session state."""
state = derive_state(has_rw, ro_count, committed)
return GetLockStateResponse(
state=state.value,
has_rw_session=has_rw,
ro_session_count=ro_count,
waiting_writers=waiting_writers,
committed=committed,
is_ready=committed and not has_rw,
)
def handle_get_allocation_state(self) -> GetAllocationStateResponse:
"""Get allocation state."""
return GetAllocationStateResponse(
allocation_count=self._memory_manager.allocation_count,
total_bytes=self._memory_manager.total_bytes,
)
# ==================== Allocation Operations ====================
def handle_allocate(self, req: AllocateRequest) -> AllocateResponse:
"""Create physical memory allocation.
Requires RW connection (enforced by server).
"""
info = self._memory_manager.allocate(req.size, req.tag)
return AllocateResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
)
def handle_export(self, allocation_id: str) -> tuple[GetAllocationResponse, int]:
"""Export allocation as POSIX FD.
Returns (response, fd). Caller must close fd after sending.
"""
fd = self._memory_manager.export_fd(allocation_id)
info = self._memory_manager.get_allocation(allocation_id)
response = GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
)
return response, fd
def handle_get_allocation(self, req: GetAllocationRequest) -> GetAllocationResponse:
"""Get allocation info."""
try:
info = self._memory_manager.get_allocation(req.allocation_id)
return GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
)
except AllocationNotFoundError:
raise ValueError(f"Unknown allocation: {req.allocation_id}") from None
def handle_list_allocations(
self, req: ListAllocationsRequest
) -> ListAllocationsResponse:
"""List all allocations."""
allocations = self._memory_manager.list_allocations(req.tag)
result = [
{
"allocation_id": info.allocation_id,
"size": info.size,
"aligned_size": info.aligned_size,
"tag": info.tag,
}
for info in allocations
]
return ListAllocationsResponse(allocations=result)
def handle_free(self, req: FreeRequest) -> FreeResponse:
"""Free single allocation.
Requires RW connection (enforced by server).
"""
success = self._memory_manager.free(req.allocation_id)
return FreeResponse(success=success)
def handle_clear_all(self) -> ClearAllResponse:
"""Clear all allocations and metadata.
Requires RW connection (enforced by server).
"""
count = self._memory_manager.clear_all()
self._metadata.clear()
return ClearAllResponse(cleared_count=count)
# ==================== Metadata Operations ====================
def handle_metadata_put(self, req: MetadataPutRequest) -> MetadataPutResponse:
self._metadata[req.key] = MetadataEntry(
req.allocation_id, req.offset_bytes, req.value
)
return MetadataPutResponse(success=True)
def handle_metadata_get(self, req: MetadataGetRequest) -> MetadataGetResponse:
entry = self._metadata.get(req.key)
if entry is None:
return MetadataGetResponse(found=False)
return MetadataGetResponse(
found=True,
allocation_id=entry.allocation_id,
offset_bytes=entry.offset_bytes,
value=entry.value,
)
def handle_metadata_delete(
self, req: MetadataDeleteRequest
) -> MetadataDeleteResponse:
return MetadataDeleteResponse(
deleted=self._metadata.pop(req.key, None) is not None
)
def handle_metadata_list(self, req: MetadataListRequest) -> MetadataListResponse:
keys = (
[k for k in self._metadata if k.startswith(req.prefix)]
if req.prefix
else list(self._metadata)
)
return MetadataListResponse(keys=sorted(keys))
def handle_get_memory_layout_hash(self) -> GetStateHashResponse:
return GetStateHashResponse(memory_layout_hash=self._memory_layout_hash)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA VMM allocation manager - pure business logic, no threading/transport.
This module contains the GMSServerMemoryManager class which handles physical GPU memory
allocations via CUDA Virtual Memory Management (VMM) API. It creates shareable
memory without mapping it locally (no CUDA context needed on the server).
The GMSServerMemoryManager is NOT thread-safe. Callers must provide external
synchronization (e.g., via LockManager ensuring single-writer access).
"""
import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
from uuid import uuid4
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import (
align_to_granularity,
check_cuda_result,
ensure_cuda_initialized,
get_allocation_granularity,
)
logger = logging.getLogger(__name__)
@dataclass
class AllocationInfo:
"""Information about a single GPU memory allocation.
Attributes:
allocation_id: Unique identifier for this allocation
size: Requested size in bytes
aligned_size: Actual size after alignment to VMM granularity
handle: CUmemGenericAllocationHandle value
tag: User-provided tag for grouping allocations
created_at: Timestamp when allocation was created
"""
allocation_id: str
size: int
aligned_size: int
handle: int
tag: str
created_at: float
class AllocationNotFoundError(Exception):
"""Raised when an allocation_id doesn't exist."""
pass
class GMSServerMemoryManager:
"""GPU Memory Service server-side memory manager.
Manages CUDA VMM physical memory allocations. This class handles the core
memory operations using CUDA Virtual Memory Management (VMM) API. It creates
physical allocations that can be exported as POSIX file descriptors for
sharing with other processes.
Key design points:
- No VA mapping: The memory manager never maps memory to virtual addresses,
so it doesn't create a CUDA context. This allows it to survive GPU
driver failures.
- NOT thread-safe: Callers must provide external synchronization.
The GMSLocalFSM's RW/RO semantics ensure single-writer access.
"""
def __init__(self, device: int = 0):
self._device = device
self._allocations: Dict[str, AllocationInfo] = {}
ensure_cuda_initialized()
self._granularity = get_allocation_granularity(device)
logger.info(
f"GMSServerMemoryManager initialized: device={device}, granularity={self._granularity}"
)
@property
def device(self) -> int:
return self._device
@property
def granularity(self) -> int:
return self._granularity
@property
def allocation_count(self) -> int:
return len(self._allocations)
@property
def total_bytes(self) -> int:
return sum(info.aligned_size for info in self._allocations.values())
def _get(self, allocation_id: str) -> AllocationInfo:
info = self._allocations.get(allocation_id)
if info is None:
raise AllocationNotFoundError(f"Unknown allocation: {allocation_id}")
return info
def _release(self, info: AllocationInfo) -> None:
(result,) = cuda.cuMemRelease(info.handle)
if result != cuda.CUresult.CUDA_SUCCESS:
logger.warning(f"cuMemRelease failed for {info.allocation_id}: {result}")
def allocate(self, size: int, tag: str = "default") -> AllocationInfo:
"""Create a physical memory allocation (no VA mapping).
Uses cuMemCreate to allocate physical GPU memory that can be exported
as a file descriptor for sharing with other processes.
Args:
size: Requested size in bytes (will be aligned up to granularity)
tag: Tag for grouping allocations (e.g., "weights", "kv_cache")
Returns:
AllocationInfo with allocation_id, aligned_size, handle
Raises:
RuntimeError: If CUDA allocation fails
"""
aligned_size = align_to_granularity(size, self._granularity)
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = self._device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, handle = cuda.cuMemCreate(aligned_size, prop, 0)
check_cuda_result(result, "cuMemCreate")
info = AllocationInfo(
allocation_id=str(uuid4()),
size=size,
aligned_size=aligned_size,
handle=int(handle),
tag=tag,
created_at=time.time(),
)
self._allocations[info.allocation_id] = info
logger.debug(
f"Allocated {info.allocation_id}: size={size}, aligned={aligned_size}, tag={tag}"
)
return info
def export_fd(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD for SCM_RIGHTS transfer.
The returned file descriptor can be sent to another process via
Unix domain socket SCM_RIGHTS. The receiving process can then
import it using cuMemImportFromShareableHandle.
IMPORTANT: The caller MUST close the returned FD after sendmsg()
to avoid file descriptor leaks.
Args:
allocation_id: ID of allocation to export
Returns:
File descriptor (caller owns, must close after sending)
Raises:
AllocationNotFoundError: If allocation_id doesn't exist
RuntimeError: If CUDA export fails
"""
info = self._get(allocation_id)
result, fd = cuda.cuMemExportToShareableHandle(
info.handle,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
0,
)
check_cuda_result(result, "cuMemExportToShareableHandle")
return int(fd)
def free(self, allocation_id: str) -> bool:
"""Release physical memory for a single allocation.
Args:
allocation_id: ID of allocation to free
Returns:
True if allocation existed and was freed, False otherwise
"""
info = self._allocations.pop(allocation_id, None)
if info is None:
return False
self._release(info)
logger.debug(f"Freed allocation: {allocation_id}")
return True
def clear_all(self) -> int:
"""Release ALL allocations.
Used by loaders before loading a new model, or during cleanup
when a writer aborts without committing.
Returns:
Number of allocations cleared
"""
count = len(self._allocations)
for info in self._allocations.values():
self._release(info)
self._allocations.clear()
logger.info(f"Cleared {count} allocations")
return count
def get_allocation(self, allocation_id: str) -> AllocationInfo:
"""Get allocation info. Raises AllocationNotFoundError if not found."""
return self._get(allocation_id)
def list_allocations(self, tag: Optional[str] = None) -> List[AllocationInfo]:
"""List all allocations, optionally filtered by tag."""
if tag is None:
return list(self._allocations.values())
return [info for info in self._allocations.values() if info.tag == tag]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Async Allocation RPC Server - Single-threaded event loop with explicit state machine.
State transitions are explicit and validated by the GMSLocalFSM class.
Operations are checked against state/mode permissions before operation.
State Machine (see locking.py for full diagram):
EMPTY: No connections, not committed
RW: Writer connected (exclusive)
COMMITTED: No connections, committed (weights valid)
RO: Reader(s) connected (shared)
"""
"""Async GMS RPC transport server."""
from __future__ import annotations
import asyncio
import logging
import os
from typing import ClassVar, Optional
import select
import socket
from typing import Optional
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
ClearAllRequest,
CommitRequest,
CommitResponse,
ErrorResponse,
ExportRequest,
FreeRequest,
GetAllocationRequest,
GetAllocationStateRequest,
GetLockStateRequest,
GetStateHashRequest,
GetEventHistoryRequest,
GetRuntimeStateRequest,
HandshakeRequest,
HandshakeResponse,
ListAllocationsRequest,
MetadataDeleteRequest,
MetadataGetRequest,
MetadataListRequest,
MetadataPutRequest,
)
from gpu_memory_service.common.protocol.wire import recv_message, send_message
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from gpu_memory_service.common.utils import fail
from .handler import RequestHandler
from .locking import Connection, GMSLocalFSM
from .allocations import AllocationNotFoundError
from .gms import GMS
from .session import Connection, InvalidTransition, OperationNotAllowed
logger = logging.getLogger(__name__)
class GMSRPCServer:
"""GPU Memory Service RPC Server.
def _is_connection_alive(conn: Connection) -> bool:
if conn.writer.is_closing():
return False
if conn.reader.at_eof() or conn.reader.exception() is not None:
return False
sock = conn.writer.get_extra_info("socket")
if sock is None:
return False
try:
fd = sock.fileno()
except OSError:
return False
if fd < 0:
return False
flags = select.POLLERR | select.POLLHUP | select.POLLNVAL
if hasattr(select, "POLLRDHUP"):
flags |= select.POLLRDHUP
poller = select.poll()
poller.register(fd, flags)
return not poller.poll(0)
Async single-threaded server using GMSLocalFSM for explicit state transitions
and operation validation. All state mutations happen through the state machine's
transition() method.
"""
class GMSRPCServer:
"""Unix-socket transport for the GPU Memory Service."""
def __init__(
self,
socket_path: str,
device: int = 0,
*,
allocation_retry_interval: float = 0.5,
allocation_retry_timeout: Optional[float] = None,
):
self.socket_path = socket_path
self.device = device
# Request handler (business logic)
self._handler = RequestHandler(device)
# State machine - handles all state transitions and permission checks
self._sm = GMSLocalFSM(on_rw_abort=self._handler.on_rw_abort)
self._waiting_writers: int = 0
# Async waiting for lock acquisition
self._condition = asyncio.Condition()
self._shutdown = False
# Session ID generation
self._next_session_id: int = 0
# Server state
self._gms = GMS(
device,
allocation_retry_interval=allocation_retry_interval,
allocation_retry_timeout=allocation_retry_timeout,
)
self._server: Optional[asyncio.Server] = None
self._running: bool = False
logger.info("GMSRPCServer initialized: device=%d", device)
logger.info(f"GMSRPCServer initialized: device={device}")
def _prepare_socket_path(self) -> None:
if not os.path.exists(self.socket_path):
return
# ==================== State Properties ====================
probe = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
probe.connect(self.socket_path)
except OSError:
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
return
finally:
probe.close()
@property
def state(self) -> ServerState:
"""Current server state (delegated to state machine)."""
return self._sm.state
raise RuntimeError(f"GMS already running at {self.socket_path}")
@property
def granularity(self) -> int:
return self._handler.granularity
def state(self):
return self._gms.state
def is_ready(self) -> bool:
"""Ready = committed and no RW connection."""
return self._sm.committed and self._sm.rw_conn is None
@property
def running(self) -> bool:
"""Whether the server is running."""
return self._running
def _generate_session_id(self) -> str:
self._next_session_id += 1
return f"session_{self._next_session_id}"
# ==================== Connection Lifecycle ====================
return self._gms.is_ready()
async def _handle_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""Handle a connection from accept to close."""
session_id = self._generate_session_id()
conn: Optional[Connection] = None
session_id = self._gms.next_session_id()
try:
conn = await self._do_handshake(reader, writer, session_id)
if conn is None:
return
await self._request_loop(conn)
except (InvalidTransition, AssertionError) as exc:
fail("fatal server error", exc_info=exc)
except ConnectionResetError:
logger.debug(f"Connection reset: {session_id}")
logger.debug("Connection reset: %s", session_id)
except asyncio.CancelledError:
raise
except Exception:
logger.exception(f"Connection error: {session_id}")
except Exception as exc:
fail("fatal server error", exc_info=exc)
finally:
await self._cleanup_connection(conn)
try:
await self._gms.cleanup_connection(conn)
except Exception as exc:
fail("fatal server error", exc_info=exc)
async def _do_handshake(
self,
......@@ -143,29 +126,57 @@ class GMSRPCServer:
writer: asyncio.StreamWriter,
session_id: str,
) -> Optional[Connection]:
"""Perform handshake and acquire lock via state machine transition."""
try:
# Server never receives FDs from clients, so no need for raw_sock
msg, _, recv_buffer = await recv_message(reader, bytearray())
except Exception:
logger.exception("Handshake recv error")
return None
if isinstance(msg, GetRuntimeStateRequest):
try:
await send_message(writer, self._gms.get_runtime_state())
except Exception as exc:
logger.debug("Runtime-state response send failed: %s", exc)
finally:
writer.close()
return None
if isinstance(msg, GetEventHistoryRequest):
try:
await send_message(writer, self._gms.get_event_history())
except Exception as exc:
logger.debug("Event-history response send failed: %s", exc)
finally:
writer.close()
return None
if not isinstance(msg, HandshakeRequest):
await send_message(writer, ErrorResponse(error="Expected HandshakeRequest"))
try:
await send_message(
writer, ErrorResponse(error="Expected HandshakeRequest")
)
except Exception:
pass
writer.close()
return None
# Acquire lock (blocks until available or timeout)
# Returns the actual granted mode (may differ from requested for rw_or_ro)
granted_mode = await self._acquire_lock(msg.lock_type, msg.timeout_ms)
granted_mode = await self._gms.acquire_lock(
msg.lock_type,
msg.timeout_ms,
session_id,
)
if granted_mode is None:
try:
await send_message(
writer, HandshakeResponse(success=False, committed=self._sm.committed)
writer,
HandshakeResponse(success=False, committed=self._gms.committed),
)
except Exception:
pass
writer.close()
return None
try:
conn = Connection(
reader=reader,
writer=writer,
......@@ -173,128 +184,35 @@ class GMSRPCServer:
session_id=session_id,
recv_buffer=recv_buffer,
)
self._gms.on_connect(conn)
except Exception:
await self._gms.cancel_connect(session_id, granted_mode)
raise
# State transition: connect
event = (
StateEvent.RW_CONNECT
if granted_mode == GrantedLockType.RW
else StateEvent.RO_CONNECT
)
self._sm.transition(event, conn)
try:
await send_message(
writer,
HandshakeResponse(
success=True,
committed=self._sm.committed,
committed=self._gms.committed,
granted_lock_type=granted_mode,
),
)
return conn
async def _acquire_lock(
self,
mode: RequestedLockType,
timeout_ms: Optional[int],
) -> Optional[GrantedLockType]:
"""Wait until lock can be acquired (uses state machine predicates).
Returns the granted lock type, or None if failed/timeout.
For rw_or_ro mode, returns RW if available immediately, else waits for RO.
"""
timeout = timeout_ms / 1000 if timeout_ms is not None else None
if mode == RequestedLockType.RW:
self._waiting_writers += 1
try:
async with self._condition:
try:
await asyncio.wait_for(
self._condition.wait_for(
lambda: self._shutdown or self._sm.can_acquire_rw()
),
timeout=timeout,
)
return None if self._shutdown else GrantedLockType.RW
except asyncio.TimeoutError:
return None
finally:
self._waiting_writers -= 1
elif mode == RequestedLockType.RO:
async with self._condition:
try:
await asyncio.wait_for(
self._condition.wait_for(
lambda: self._shutdown
or self._sm.can_acquire_ro(self._waiting_writers)
),
timeout=timeout,
)
return None if self._shutdown else GrantedLockType.RO
except asyncio.TimeoutError:
return None
elif mode == RequestedLockType.RW_OR_RO:
# Auto mode: try RW if available immediately AND no committed weights,
# otherwise wait for RO (to import existing weights)
async with self._condition:
# Check if RW is available AND no committed weights exist
# If weights are already committed, prefer RO to import them
if self._sm.can_acquire_rw() and not self._sm.committed:
return GrantedLockType.RW
# Either RW not available OR weights already committed - wait for RO
if self._sm.committed:
logger.info(
"RW_OR_RO: Weights already committed, preferring RO to import"
except Exception as exc:
logger.warning(
"Handshake failed after acquiring %s for session %s: %s",
granted_mode.value,
session_id,
exc,
)
else:
logger.info(
"RW_OR_RO: RW not available (another writer active), "
"falling back to RO"
)
try:
await asyncio.wait_for(
self._condition.wait_for(
lambda: self._shutdown
or self._sm.can_acquire_ro(self._waiting_writers)
),
timeout=timeout,
)
return None if self._shutdown else GrantedLockType.RO
except asyncio.TimeoutError:
await self._gms.cleanup_connection(conn)
return None
return None
async def _cleanup_connection(self, conn: Optional[Connection]) -> None:
"""Clean up after connection closes via state machine transition."""
if conn is None:
return
# State transition: disconnect
if conn.mode == GrantedLockType.RW:
if self._sm.rw_conn is conn and not self._sm.committed:
# RW abort - state machine callback handles cleanup
self._sm.transition(StateEvent.RW_ABORT, conn)
elif self._sm.rw_conn is conn:
# Already committed, no transition needed (commit already did it)
pass
else:
if conn in self._sm.ro_conns:
self._sm.transition(StateEvent.RO_DISCONNECT, conn)
await conn.close()
async with self._condition:
self._condition.notify_all()
# ==================== Request Handling ====================
return conn
async def _request_loop(self, conn: Connection) -> None:
"""Process requests until close or commit."""
while self._running:
while True:
try:
# Server never receives FDs from clients, so no need for raw_socket
msg, _, conn.recv_buffer = await recv_message(
conn.reader, conn.recv_buffer
)
......@@ -302,136 +220,78 @@ class GMSRPCServer:
return
except asyncio.CancelledError:
raise
except Exception:
logger.exception("Recv error")
except Exception as exc:
logger.warning("Recv error on session %s: %s", conn.session_id, exc)
return
if msg is None:
continue
fd = -1
try:
response, fd, should_close = await self._dispatch(conn, msg)
if response is not None:
response, fd, should_close = await self._gms.handle_request(
conn,
msg,
lambda: _is_connection_alive(conn),
)
except ConnectionAbortedError as exc:
logger.warning(
"Connection lost during %s on session %s: %s",
type(msg).__name__,
conn.session_id,
exc,
)
return
except (
OperationNotAllowed,
ValueError,
TimeoutError,
AllocationNotFoundError,
) as exc:
logger.warning(
"Rejected %s from session %s: %s",
type(msg).__name__,
conn.session_id,
exc,
)
try:
await send_message(conn.writer, ErrorResponse(error=str(exc)))
except Exception as send_exc:
logger.warning(
"Failed to send ErrorResponse for %s on session %s: %s",
type(msg).__name__,
conn.session_id,
send_exc,
)
return
continue
except (InvalidTransition, AssertionError) as exc:
fail("fatal server error", exc_info=exc)
except Exception as exc:
fail("fatal server error", exc_info=exc)
try:
await send_message(conn.writer, response, fd)
except Exception as exc:
logger.warning(
"Response send failed for %s on session %s: %s",
type(msg).__name__,
conn.session_id,
exc,
)
return
finally:
if fd >= 0:
os.close(fd)
if should_close:
return
except Exception as e:
logger.exception("Request error")
await send_message(conn.writer, ErrorResponse(error=str(e)))
# Dispatch table: message type -> handler method name
# Handlers take (msg) and return response. Special cases handled separately.
_HANDLERS: ClassVar[dict[type, str]] = {
AllocateRequest: "handle_allocate",
GetAllocationRequest: "handle_get_allocation",
ListAllocationsRequest: "handle_list_allocations",
FreeRequest: "handle_free",
MetadataPutRequest: "handle_metadata_put",
MetadataGetRequest: "handle_metadata_get",
MetadataDeleteRequest: "handle_metadata_delete",
MetadataListRequest: "handle_metadata_list",
}
async def _dispatch(self, conn: Connection, msg) -> tuple[object, int, bool]:
"""Dispatch request to handler. Returns (response, fd, should_close)."""
msg_type = type(msg)
self._sm.check_operation(msg_type, conn)
# Special cases
if msg_type is CommitRequest:
return await self._handle_commit(conn)
if msg_type is GetLockStateRequest:
return (
self._handler.handle_get_lock_state(
self._sm.rw_conn is not None,
self._sm.ro_count,
self._waiting_writers,
self._sm.committed,
),
-1,
False,
)
if msg_type is GetAllocationStateRequest:
return self._handler.handle_get_allocation_state(), -1, False
if msg_type is ExportRequest:
response, fd = self._handler.handle_export(msg.allocation_id)
return response, fd, False
if msg_type is ClearAllRequest:
return self._handler.handle_clear_all(), -1, False
if msg_type is GetStateHashRequest:
return self._handler.handle_get_memory_layout_hash(), -1, False
# Standard dispatch: handler takes msg, returns response
handler_name = self._HANDLERS.get(msg_type)
if handler_name:
handler = getattr(self._handler, handler_name)
return handler(msg), -1, False
raise ValueError(f"Unknown request: {msg_type.__name__}")
async def _handle_commit(self, conn: Connection) -> tuple[object, int, bool]:
"""Handle commit via state machine transition - atomic with disconnect."""
self._handler.on_commit()
self._sm.transition(StateEvent.RW_COMMIT, conn)
await send_message(conn.writer, CommitResponse(success=True))
await conn.close()
async with self._condition:
self._condition.notify_all()
return None, -1, True
# ==================== Server Lifecycle ====================
async def start(self) -> None:
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
async def serve(self) -> None:
self._prepare_socket_path()
self._server = await asyncio.start_unix_server(
self._handle_connection, path=self.socket_path
self._handle_connection,
path=self.socket_path,
)
self._running = True
logger.info(f"Server started: {self.socket_path}")
async def stop(self) -> None:
self._running = False
self._shutdown = True
async with self._condition:
self._condition.notify_all()
if self._server:
self._server.close()
await self._server.wait_closed()
self._server = None
# Close connections (bypassing state machine - this is shutdown)
if self._sm.rw_conn:
await self._sm.rw_conn.close()
for conn in list(self._sm.ro_conns):
await conn.close()
self._handler.on_shutdown()
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
logger.info("Server stopped")
async def serve_forever(self) -> None:
await self.start()
try:
while self._running:
await asyncio.sleep(1)
finally:
await self.stop()
logger.info("Server started: %s", self.socket_path)
await self._server.serve_forever()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Connection and state machine for GPU Memory Service.
This module handles:
- Connection: Represents an active client connection
- GMSLocalFSM: Explicit state transitions with validated permissions
State Diagram:
EMPTY ──RW_CONNECT──► RW ──RW_COMMIT──► COMMITTED
▲ │ │
│ │ │
└───RW_ABORT─────────┘ │
COMMITTED ◄──RO_DISCONNECT (last)── RO ◄──RO_CONNECT
│ ▲
│ │
└──RO_CONNECT──────┘
└──RO_DISCONNECT───┘ (not last)
"""
"""Server-side connection, FSM, and waiter state."""
from __future__ import annotations
import asyncio
import logging
import socket
from dataclasses import dataclass, field
from typing import Callable, Optional, Set
from typing import Optional, Set
from gpu_memory_service.common.types import (
RO_ALLOWED,
RW_ALLOWED,
RW_REQUIRED,
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
logger = logging.getLogger(__name__)
# =============================================================================
# Connection
# =============================================================================
@dataclass(eq=False)
class Connection:
"""Represents an active connection.
The existence of Connection objects IS the state - we don't track
sessions separately. When a Connection is removed, the lock is released.
Note: eq=False disables auto-generated __eq__ so we can use default
object identity for equality and add __hash__ for use in sets.
"""
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
mode: GrantedLockType
......@@ -64,16 +29,9 @@ class Connection:
recv_buffer: bytearray = field(default_factory=bytearray)
def __hash__(self) -> int:
"""Hash based on session_id (immutable identifier)."""
return hash(self.session_id)
@property
def raw_socket(self) -> socket.socket:
"""Get underlying socket for FD passing."""
return self.writer.get_extra_info("socket")
async def close(self) -> None:
"""Close the connection."""
self.writer.close()
try:
await self.writer.wait_closed()
......@@ -81,80 +39,49 @@ class Connection:
pass
# =============================================================================
# State Machine
# =============================================================================
class InvalidTransition(Exception):
"""Raised when an invalid state transition is attempted."""
pass
class OperationNotAllowed(Exception):
"""Raised when an operation is not allowed in the current state/mode."""
pass
@dataclass(frozen=True)
class Transition:
"""A valid state transition.
Attributes:
from_states: Set of states this transition can originate from
event: The event that triggers this transition
to_state: The resulting state (or None if conditional)
condition: Optional condition function for conditional transitions
"""
from_states: frozenset[ServerState]
event: StateEvent
to_state: Optional[ServerState]
condition: Optional[str] = None # Name of condition method
condition: Optional[str] = None
# Transition table - the single source of truth for valid state transitions
TRANSITIONS: list[Transition] = [
# From EMPTY or COMMITTED: RW can connect
# Writer acquires exclusive lock
Transition(
from_states=frozenset({ServerState.EMPTY, ServerState.COMMITTED}),
event=StateEvent.RW_CONNECT,
to_state=ServerState.RW,
),
# From RW: commit publishes and transitions to COMMITTED
# Writer publishes and releases lock
Transition(
from_states=frozenset({ServerState.RW}),
event=StateEvent.RW_COMMIT,
to_state=ServerState.COMMITTED,
),
# From RW: abort (disconnect without commit) transitions to EMPTY
# Writer aborts, state invalidated
Transition(
from_states=frozenset({ServerState.RW}),
event=StateEvent.RW_ABORT,
to_state=ServerState.EMPTY,
),
# From COMMITTED or RO: RO can connect
# Reader acquires shared lock
Transition(
from_states=frozenset({ServerState.COMMITTED, ServerState.RO}),
event=StateEvent.RO_CONNECT,
to_state=ServerState.RO,
),
# From RO: reader disconnect (not last) stays in RO
# Reader leaves, others remain
Transition(
from_states=frozenset({ServerState.RO}),
event=StateEvent.RO_DISCONNECT,
to_state=ServerState.RO,
condition="has_remaining_readers",
),
# From RO: last reader disconnect transitions to COMMITTED
# Last reader leaves
Transition(
from_states=frozenset({ServerState.RO}),
event=StateEvent.RO_DISCONNECT,
......@@ -164,52 +91,19 @@ TRANSITIONS: list[Transition] = [
]
@dataclass
class TransitionRecord:
"""Record of a state transition for debugging/auditing."""
from_state: ServerState
event: StateEvent
to_state: ServerState
session_id: Optional[str] = None
class GMSLocalFSM:
"""Explicit state machine for GPU Memory Service.
State is DERIVED from actual connection objects:
- _rw_conn: The active RW connection (or None)
- _ro_conns: Set of active RO connections
- _committed: Whether allocations have been committed
All state mutations happen through explicit transitions.
"""
"""Explicit connection/lock state machine."""
def __init__(self, on_rw_abort: Optional[Callable[[], None]] = None):
"""Initialize the state machine.
Args:
on_rw_abort: Callback invoked when RW aborts (for cleanup)
"""
# Connection state - THIS IS THE SOURCE OF TRUTH
def __init__(self):
self._rw_conn: Optional[Connection] = None
self._ro_conns: Set[Connection] = set()
self._committed: bool = False
# Callback for RW abort cleanup
self._on_rw_abort = on_rw_abort
# Transition history for debugging
self._transition_log: list[TransitionRecord] = []
# ==================== State Properties ====================
self._committed = False
@property
def state(self) -> ServerState:
"""Derive current state from connection objects."""
if self._rw_conn is not None:
return ServerState.RW
if len(self._ro_conns) > 0:
if self._ro_conns:
return ServerState.RO
if self._committed:
return ServerState.COMMITTED
......@@ -217,41 +111,27 @@ class GMSLocalFSM:
@property
def rw_conn(self) -> Optional[Connection]:
"""The active RW connection, if any."""
return self._rw_conn
@property
def ro_conns(self) -> Set[Connection]:
"""Set of active RO connections."""
return self._ro_conns
@property
def ro_count(self) -> int:
"""Number of active RO connections."""
return len(self._ro_conns)
@property
def committed(self) -> bool:
"""Whether allocations have been committed."""
return self._committed
@property
def transition_log(self) -> list[TransitionRecord]:
"""History of state transitions."""
return self._transition_log
# ==================== Transition Conditions ====================
def _has_remaining_readers(self, conn: Connection) -> bool:
"""Check if there are readers remaining after removing conn."""
return len(self._ro_conns) > 1 or conn not in self._ro_conns
def _is_last_reader(self, conn: Connection) -> bool:
"""Check if conn is the last reader."""
return len(self._ro_conns) == 1 and conn in self._ro_conns
def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
"""Evaluate a named condition."""
if condition is None:
return True
if condition == "has_remaining_readers":
......@@ -260,143 +140,220 @@ class GMSLocalFSM:
return self._is_last_reader(conn)
raise ValueError(f"Unknown condition: {condition}")
# ==================== State Transitions ====================
def _find_transition(
self, from_state: ServerState, event: StateEvent, conn: Connection
self,
from_state: ServerState,
event: StateEvent,
conn: Connection,
) -> Optional[Transition]:
"""Find the applicable transition for the given event."""
for t in TRANSITIONS:
if from_state not in t.from_states:
for transition in TRANSITIONS:
if from_state not in transition.from_states:
continue
if t.event != event:
if transition.event != event:
continue
if not self._check_condition(t.condition, conn):
if not self._check_condition(transition.condition, conn):
continue
return t
return transition
return None
def _apply_event(self, event: StateEvent, conn: Connection) -> None:
"""Mutate internal state based on event."""
match event:
case StateEvent.RW_CONNECT:
if event == StateEvent.RW_CONNECT:
self._rw_conn = conn
self._committed = False # Invalidate on RW connect
case StateEvent.RW_COMMIT:
self._committed = False
elif event == StateEvent.RW_COMMIT:
self._committed = True
self._rw_conn = None
case StateEvent.RW_ABORT:
elif event == StateEvent.RW_ABORT:
self._rw_conn = None
if self._on_rw_abort:
self._on_rw_abort()
case StateEvent.RO_CONNECT:
elif event == StateEvent.RO_CONNECT:
self._ro_conns.add(conn)
case StateEvent.RO_DISCONNECT:
elif event == StateEvent.RO_DISCONNECT:
self._ro_conns.discard(conn)
def transition(self, event: StateEvent, conn: Connection) -> ServerState:
"""Execute a state transition.
transition = self._find_transition(self.state, event, conn)
if transition is None:
raise InvalidTransition(
f"No transition for {event.name} from state {self.state.name} "
f"(session={conn.session_id})"
)
self._apply_event(event, conn)
return self.state
Args:
event: The triggering event
conn: The connection involved in the transition
def check_operation(self, msg_type: type, conn: Connection) -> None:
if conn.mode == GrantedLockType.RW and msg_type not in RW_ALLOWED:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed for RW session in state {self.state.name}"
)
if conn.mode == GrantedLockType.RO and msg_type not in RO_ALLOWED:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed for RO session in state {self.state.name}"
)
if msg_type in RW_REQUIRED and conn.mode != GrantedLockType.RW:
raise OperationNotAllowed(
f"{msg_type.__name__} requires RW session, got {conn.mode.value}"
)
Returns:
The new state after the transition
def can_acquire_rw(self) -> bool:
return self._rw_conn is None and not self._ro_conns
Raises:
InvalidTransition: If the transition is not valid from current state
"""
from_state = self.state
session_id = conn.session_id if conn else None
def can_acquire_ro(self, waiting_writers: int) -> bool:
return self._committed and self._rw_conn is None and waiting_writers == 0
# Find valid transition
trans = self._find_transition(from_state, event, conn)
if trans is None:
raise InvalidTransition(
f"No transition for {event.name} from state {from_state.name} "
f"(session={session_id})"
)
# Apply the transition
self._apply_event(event, conn)
to_state = self.state
@dataclass(frozen=True)
class SessionSnapshot:
state: ServerState
has_rw_session: bool
ro_session_count: int
waiting_writers: int
committed: bool
is_ready: bool
# Validate we ended up in expected state
if trans.to_state is not None and to_state != trans.to_state:
raise InvalidTransition(
f"Transition mismatch: expected {trans.to_state.name}, "
f"got {to_state.name}"
)
# Record transition
record = TransitionRecord(
from_state,
event,
to_state,
session_id=session_id,
class GMSSessionManager:
"""Owns lock transitions, waiter coordination, and cleanup."""
def __init__(self):
self._locking = GMSLocalFSM()
self._waiting_writers = 0
self._reserved_rw_session_id: Optional[str] = None
self._condition = asyncio.Condition()
self._next_session_id = 0
@property
def state(self) -> ServerState:
return self._locking.state
def next_session_id(self) -> str:
self._next_session_id += 1
return f"session_{self._next_session_id}"
def snapshot(self) -> SessionSnapshot:
has_rw_session = self._locking.rw_conn is not None
return SessionSnapshot(
state=self._locking.state,
has_rw_session=has_rw_session,
ro_session_count=self._locking.ro_count,
waiting_writers=self._waiting_writers,
committed=self._locking.committed,
is_ready=self._locking.committed and not has_rw_session,
)
self._transition_log.append(record)
logger.info(
f"State transition: {from_state.name} --{event.name}--> {to_state.name} "
f"(session={session_id})"
def _can_grant_rw(self) -> bool:
return self._reserved_rw_session_id is None and self._locking.can_acquire_rw()
def _can_grant_ro(self) -> bool:
return self._reserved_rw_session_id is None and self._locking.can_acquire_ro(
self._waiting_writers
)
return to_state
def _can_grant_rw_or_ro(self) -> bool:
if self._can_grant_ro():
return True
return self._can_grant_rw() and not self._locking.committed
# ==================== Operation Permissions ====================
async def acquire_lock(
self,
mode: RequestedLockType,
timeout_ms: Optional[int],
session_id: str,
) -> Optional[GrantedLockType]:
timeout = timeout_ms / 1000 if timeout_ms is not None else None
def check_operation(self, msg_type: type, conn: Connection) -> None:
"""Check if a request type is allowed for the given connection.
Args:
msg_type: The request message type (e.g., AllocateRequest)
conn: The connection attempting the operation
Raises:
OperationNotAllowed: If the operation is not permitted
"""
current_state = self.state
# Determine allowed operations based on state
if current_state == ServerState.RW:
allowed = RW_ALLOWED
elif current_state == ServerState.RO:
allowed = RO_ALLOWED
else:
allowed = frozenset() # EMPTY and COMMITTED have no connections
if msg_type not in allowed:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed in state {current_state.name}"
if mode == RequestedLockType.RW:
try:
async with self._condition:
self._waiting_writers += 1
try:
await asyncio.wait_for(
self._condition.wait_for(self._can_grant_rw),
timeout=timeout,
)
# Check connection mode
if msg_type in RW_REQUIRED and conn.mode != GrantedLockType.RW:
raise OperationNotAllowed(
f"{msg_type.__name__} requires RW connection, "
f"but connection is {conn.mode.value}"
except asyncio.TimeoutError:
return None
self._reserved_rw_session_id = session_id
return GrantedLockType.RW
finally:
async with self._condition:
self._waiting_writers -= 1
self._condition.notify_all()
if mode == RequestedLockType.RO:
async with self._condition:
try:
await asyncio.wait_for(
self._condition.wait_for(self._can_grant_ro),
timeout=timeout,
)
except asyncio.TimeoutError:
return None
return GrantedLockType.RO
# ==================== Lock Acquisition Predicates ====================
def can_acquire_rw(self) -> bool:
"""Check if RW lock can be acquired now.
async with self._condition:
if self._can_grant_rw() and not self._locking.committed:
self._reserved_rw_session_id = session_id
return GrantedLockType.RW
try:
await asyncio.wait_for(
self._condition.wait_for(self._can_grant_rw_or_ro),
timeout=timeout,
)
except asyncio.TimeoutError:
return None
if self._can_grant_rw() and not self._locking.committed:
self._reserved_rw_session_id = session_id
return GrantedLockType.RW
return GrantedLockType.RO
async def cancel_connect(
self,
session_id: str,
mode: Optional[GrantedLockType],
) -> None:
if mode != GrantedLockType.RW:
return
async with self._condition:
if self._reserved_rw_session_id == session_id:
self._reserved_rw_session_id = None
self._condition.notify_all()
def on_connect(self, conn: Connection) -> None:
if conn.mode == GrantedLockType.RW:
if self._reserved_rw_session_id != conn.session_id:
raise AssertionError(
f"RW session {conn.session_id} was not reserved before connect"
)
self._reserved_rw_session_id = None
event = (
StateEvent.RW_CONNECT
if conn.mode == GrantedLockType.RW
else StateEvent.RO_CONNECT
)
self._locking.transition(event, conn)
RW can only be acquired if:
- No current RW holder
- No RO holders
def on_commit(self, conn: Connection) -> None:
self._locking.transition(StateEvent.RW_COMMIT, conn)
Note: This allows RW from COMMITTED state (for explicit reload).
For rw_or_ro mode, callers should also check `committed` to prefer RO.
"""
return self._rw_conn is None and len(self._ro_conns) == 0
def check_operation(self, msg_type: type, conn: Connection) -> None:
self._locking.check_operation(msg_type, conn)
def can_acquire_ro(self, waiting_writers: int) -> bool:
"""Check if RO lock can be acquired now.
def begin_cleanup(self, conn: Optional[Connection]) -> StateEvent | None:
if conn is None:
return None
Args:
waiting_writers: Number of writers waiting for the lock
"""
return self._rw_conn is None and waiting_writers == 0 and self._committed
event = None
if conn.mode == GrantedLockType.RW:
if self._locking.rw_conn is conn and not self._locking.committed:
self._locking.transition(StateEvent.RW_ABORT, conn)
event = StateEvent.RW_ABORT
elif conn in self._locking.ro_conns:
self._locking.transition(StateEvent.RO_DISCONNECT, conn)
event = StateEvent.RO_DISCONNECT
return event
async def finish_cleanup(self, conn: Optional[Connection]) -> None:
if conn is not None:
await conn.close()
async with self._condition:
self._condition.notify_all()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service Shadow Engine Failover Test for SGLang."""
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from .utils.common import run_shadow_failover_test
from .utils.sglang import SGLangWithGMSProcess
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover(
request, runtime_services, gms_ports, predownload_models
):
ports = gms_ports
run_shadow_failover_test(
request,
ports,
make_shadow=lambda: SGLangWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
),
make_primary=lambda: SGLangWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_sglang"],
ports["frontend"],
),
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service Shadow Engine Failover Test for vLLM."""
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from .utils.common import run_shadow_failover_test
from .utils.vllm import VLLMWithGMSProcess
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover(
request, runtime_services, gms_ports, predownload_models
):
ports = gms_ports
run_shadow_failover_test(
request,
ports,
make_shadow=lambda: VLLMWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
),
make_primary=lambda: VLLMWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_kv_event"],
ports["primary_nixl"],
ports["frontend"],
),
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GPU Memory Service Basic Sleep/Wake Test for SGLang.
Tests the basic sleep/wake cycle of a single SGLang engine using the GPU Memory
Service for VA-stable weight offloading.
"""
import logging
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from .utils.common import GMSServerProcess, get_gpu_memory_used, send_completion
from .utils.sglang import SGLangWithGMSProcess
logger = logging.getLogger(__name__)
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake(request, runtime_services, gms_ports, predownload_models):
"""Test basic sleep/wake with GPU Memory Service.
1. Start GMS server and SGLang engine with GMS integration
2. Run initial inference to verify engine works
3. Put engine to sleep and verify GPU memory is freed
4. Wake engine and verify inference still works
"""
ports = gms_ports
with GMSServerProcess(request, device=0):
with DynamoFrontendProcess(request, frontend_port=ports["frontend"]):
with SGLangWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
) as engine:
# Initial inference
result = send_completion(ports["frontend"])
logger.info(f"Initial inference result: {result}")
assert result["choices"]
mem_before = get_gpu_memory_used()
logger.info(f"Memory before sleep: {mem_before / (1 << 20):.0f} MB")
# Sleep (release memory occupation)
sleep_result = engine.sleep()
assert sleep_result["status"] == "ok"
mem_after_sleep = get_gpu_memory_used()
logger.info(f"Memory after sleep: {mem_after_sleep / (1 << 20):.0f} MB")
assert mem_after_sleep < mem_before, "Sleep should reduce memory"
# Wake (resume memory occupation)
wake_result = engine.wake()
assert wake_result["status"] == "ok"
# Inference after wake
result = send_completion(ports["frontend"], "Goodbye")
logger.info(f"Post-wake inference result: {result}")
assert result["choices"]
logger.info(
f"Memory freed: {(mem_before - mem_after_sleep) / (1 << 20):.0f} MB"
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GPU Memory Service Basic Sleep/Wake Test for vLLM.
Tests the basic sleep/wake cycle of a single vLLM engine using the GPU Memory
Service for VA-stable weight offloading.
"""
import logging
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from .utils.common import GMSServerProcess, get_gpu_memory_used, send_completion
from .utils.vllm import VLLMWithGMSProcess
logger = logging.getLogger(__name__)
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake(request, runtime_services, gms_ports, predownload_models):
"""Test basic sleep/wake with GPU Memory Service.
1. Start GMS server and vLLM engine with GMS integration
2. Run initial inference to verify engine works
3. Put engine to sleep and verify GPU memory is freed
4. Wake engine and verify inference still works
"""
ports = gms_ports
with GMSServerProcess(request, device=0):
with DynamoFrontendProcess(request, frontend_port=ports["frontend"]):
with VLLMWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
) as engine:
# Initial inference
result = send_completion(ports["frontend"])
logger.info(f"Initial inference result: {result}")
assert result["choices"]
mem_before = get_gpu_memory_used()
logger.info(f"Memory before sleep: {mem_before / (1 << 20):.0f} MB")
# Sleep
sleep_result = engine.sleep()
assert sleep_result["status"] == "ok"
mem_after_sleep = get_gpu_memory_used()
logger.info(f"Memory after sleep: {mem_after_sleep / (1 << 20):.0f} MB")
assert mem_after_sleep < mem_before, "Sleep should reduce memory"
# Wake
wake_result = engine.wake()
assert wake_result["status"] == "ok"
# Inference after wake
result = send_completion(ports["frontend"], "Goodbye")
logger.info(f"Post-wake inference result: {result}")
assert result["choices"]
logger.info(
f"Memory freed: {(mem_before - mem_after_sleep) / (1 << 20):.0f} MB"
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Shared utilities for GPU Memory Service tests.
This module provides process managers and helper functions that are
backend-agnostic and can be used by vLLM, SGLang, or other backends.
"""
import logging
import os
import shutil
import signal
import time
from typing import Callable
import pynvml
import requests
from gpu_memory_service.common.utils import get_socket_path
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
logger = logging.getLogger(__name__)
def get_gpu_memory_used(device: int = 0) -> int:
"""Get GPU memory usage in bytes for the specified device."""
pynvml.nvmlInit()
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetMemoryInfo(handle).used
finally:
pynvml.nvmlShutdown()
def kill_force(
process: ManagedProcess,
timeout_s: float = 30.0,
poll_interval_s: float = 0.5,
) -> None:
"""SIGKILL a process group and wait for GPU memory reclamation.
Snapshots GPU memory before the kill, sends SIGKILL to the entire
process group, reaps the zombie, then polls pynvml until the CUDA
driver finishes asynchronous cleanup (memory drops below the
pre-kill snapshot).
"""
mem_before = get_gpu_memory_used()
pid = process.get_pid()
if pid is None:
logger.warning("kill_force: no PID available")
return
try:
pgid = os.getpgid(pid)
logger.info(f"kill_force: sending SIGKILL to process group {pgid} (pid={pid})")
os.killpg(pgid, signal.SIGKILL)
except ProcessLookupError:
logger.warning(f"kill_force: process {pid} already dead")
return
# Reap the process to avoid zombies
try:
os.waitpid(pid, 0)
except ChildProcessError:
pass
# Wait for CUDA driver to asynchronously reclaim GPU memory
start = time.time()
mem_after = mem_before
while time.time() - start < timeout_s:
mem_after = get_gpu_memory_used()
if mem_after < mem_before:
break
time.sleep(poll_interval_s)
freed_mb = (mem_before - mem_after) / (1 << 20)
logger.info(
f"kill_force: before={mem_before / (1 << 30):.2f} GiB, "
f"after={mem_after / (1 << 30):.2f} GiB, freed={freed_mb:.0f} MB"
)
def send_completion(
port: int, prompt: str = "Hello", max_retries: int = 3, retry_delay: float = 1.0
) -> dict:
"""Send a completion request to the frontend.
Includes retry logic to handle transient failures from stale routing
(e.g., after failover when etcd still has dead instance entries).
"""
last_error = None
for attempt in range(max_retries):
try:
r = requests.post(
f"http://localhost:{port}/v1/completions",
json={
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"max_tokens": 20,
},
timeout=120,
)
r.raise_for_status()
result = r.json()
assert result.get("choices"), "No choices in response"
if attempt > 0:
logger.info(f"send_completion succeeded after {attempt + 1} attempts")
return result
except (requests.exceptions.RequestException, AssertionError) as e:
last_error = e
if attempt < max_retries - 1:
logger.debug(
f"send_completion attempt {attempt + 1}/{max_retries} failed: {e}"
)
time.sleep(retry_delay)
raise last_error # type: ignore
class GMSServerProcess(ManagedProcess):
"""Manages GMS server lifecycle for tests."""
def __init__(self, request, device: int):
self.device = device
self.socket_path = get_socket_path(device)
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
log_dir = f"{request.node.name}_gms_{device}"
shutil.rmtree(log_dir, ignore_errors=True)
super().__init__(
command=["python3", "-m", "gpu_memory_service", "--device", str(device)],
env={**os.environ, "DYN_LOG": "debug"},
timeout=60,
display_output=True,
terminate_all_matching_process_names=False,
log_dir=log_dir,
health_check_funcs=[self._socket_ready],
)
def __exit__(self, exc_type, exc_val, exc_tb):
try:
return super().__exit__(exc_type, exc_val, exc_tb)
finally:
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
def _socket_ready(self, timeout: float = 30) -> bool:
start = time.time()
while time.time() - start < timeout:
if os.path.exists(self.socket_path):
return True
time.sleep(0.1)
return False
def run_shadow_failover_test(
request,
ports: dict,
make_shadow: Callable[[], ManagedProcess],
make_primary: Callable[[], ManagedProcess],
) -> None:
"""Shared shadow-engine failover flow for both vLLM and SGLang.
1. Start shadow -> verify inference
2. Sleep shadow -> log memory freed
3. Start primary -> verify inference
4. kill -9 primary -> wait for GPU memory reclamation
5. Wake shadow -> verify inference x 3
"""
frontend_port = ports["frontend"]
with GMSServerProcess(request, device=0):
with DynamoFrontendProcess(request, frontend_port=frontend_port):
with make_shadow() as shadow:
# Shadow inference
result = send_completion(frontend_port)
assert result["choices"], "Shadow inference failed"
logger.info(f"Shadow inference OK: {result}")
# Sleep shadow
mem_before = get_gpu_memory_used()
assert shadow.sleep()["status"] == "ok"
mem_after = get_gpu_memory_used()
logger.info(
f"Shadow sleep: {mem_before / (1 << 30):.2f} -> "
f"{mem_after / (1 << 30):.2f} GiB "
f"(freed {(mem_before - mem_after) / (1 << 20):.0f} MB)"
)
# Primary: start, verify, kill -9
with make_primary() as primary:
result = send_completion(frontend_port, "Primary test")
assert result["choices"], "Primary inference failed"
logger.info(f"Primary inference OK: {result}")
kill_force(primary)
# Wake shadow, verify 3x
assert shadow.wake()["status"] == "ok"
for i in range(3):
result = send_completion(frontend_port, f"Verify {i}")
assert result["choices"], f"Verification {i} failed"
logger.info("All verification passed")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment