Unverified Commit 39a6a240 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

refactor: simplify GPU Memory Service integrations and module boundaries (#7875)

parent 02666f04
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Hybrid torch_memory_saver implementation for GPU Memory Service. """torch_memory_saver implementation for GPU Memory Service.
This module uses: SGLang with GMS owns exactly two memory classes:
1. GPU Memory Service for "weights" (shared RO/RW publish flow) 1. "weights" via the shared RO/RW publish flow
2. GPU Memory Service for "kv_cache" (RW-only failover flow) 2. "kv_cache" via the RW failover flow
3. torch_memory_saver for any remaining tags
Unsupported release/resume tags stay no-ops with a warning so the generic
SGLang memory-control API can still pass broader tag sets without reintroducing
the old torch-memory-saver fallback. `cuda_graph` is a hard error because the
pauseable CUDA-graph path depends on the LD_PRELOAD torch allocator hooks that
GMS intentionally does not use.
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional from typing import Optional
import torch import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager from gpu_memory_service.client.torch.allocator import (
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool get_or_create_gms_client_memory_manager,
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType gms_use_mem_pool,
)
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common.utils import GMS_TAGS, finalize_gms_write
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Published weights must come back RO, while KV cache always resumes in a fresh
# RW epoch so the restored engine can rebuild mutable cache state.
_TAG_LOCK_TYPES = {"weights": RequestedLockType.RO, "kv_cache": RequestedLockType.RW}
def _pause_resume_tags(tag: Optional[str]) -> tuple[str, ...]:
    """Resolve a pause/resume ``tag`` to the GMS tags that should be acted on.

    Args:
        tag: ``None`` to select every GMS tag, or a single tag name.

    Returns:
        ``GMS_TAGS`` when ``tag`` is ``None``; a one-element tuple when ``tag``
        is a key of ``_TAG_LOCK_TYPES``; otherwise an empty tuple, so callers
        that loop over the result do nothing.
    """
    if tag is None:
        return GMS_TAGS
    if tag in _TAG_LOCK_TYPES:
        return (tag,)
    # Unknown tags are reported but not treated as errors: the caller simply
    # gets nothing to iterate over.
    logger.warning(
        "[GMS] Ignoring unsupported torch_memory_saver tag %r; supported tags are %s",
        tag,
        list(GMS_TAGS),
    )
    return ()
def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]: def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
"""Get the GMS memory saver impl from the torch_memory_saver singleton.""" """Get the GMS memory saver impl from the torch_memory_saver singleton."""
...@@ -39,170 +60,126 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]: ...@@ -39,170 +60,126 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
class GMSMemorySaverImpl: class GMSMemorySaverImpl:
"""Hybrid implementation: GMS for weights and KV cache.""" """SGLang memory saver implementation backed only by GMS."""
def __init__( def __init__(
self, self,
torch_impl: "_TorchMemorySaverImpl",
device_index: int, device_index: int,
mode=None, mode=None,
): ):
self._torch_impl = torch_impl self._device = torch.device("cuda", device_index)
self._device_index = device_index self.imported_weights_bytes = 0
self._requested_mode = mode requested_mode = mode or RequestedLockType.RW_OR_RO
self._disabled = False self.allocators = {
self._imported_weights_bytes: int = 0 tag: get_or_create_gms_client_memory_manager(
get_socket_path(device_index, tag),
self._weights_allocator: Optional["GMSClientMemoryManager"]
self._kv_cache_allocator: "GMSClientMemoryManager"
self._mode: str
(
self._weights_allocator,
self._kv_cache_allocator,
self._mode,
) = self._init_allocators()
logger.info(
"[GMS] Initialized weights=%s mode, kv_cache=RW (device=%d)",
self._mode.upper(),
device_index, device_index,
# weights follow the configured publish/import mode; kv_cache is
# always mutable and therefore always needs an RW session.
mode=requested_mode if tag == "weights" else RequestedLockType.RW,
tag=tag,
) )
for tag in GMS_TAGS
}
def _init_allocators(
self,
) -> tuple[Optional["GMSClientMemoryManager"], "GMSClientMemoryManager", str,]:
"""Create allocator with mode from config (default: RW_OR_RO)."""
mode = self._requested_mode or RequestedLockType.RW_OR_RO
weights_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "weights"),
self._device_index,
mode=mode,
tag="weights",
)
kv_cache_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "kv_cache"),
self._device_index,
mode=RequestedLockType.RW,
tag="kv_cache",
)
granted_mode = weights_allocator.granted_lock_type
if granted_mode == GrantedLockType.RW:
actual_mode = "write"
else:
actual_mode = "read"
logger.info( logger.info(
"[GMS] Initialized in AUTO mode, granted=%s (device=%d)", "[GMS] Initialized weights: requested=%s granted=%s (device=%d)",
actual_mode.upper(), requested_mode.name,
self._device_index, self.allocators["weights"].granted_lock_type.name,
device_index,
) )
return weights_allocator, kv_cache_allocator, actual_mode
def _is_weights_tag(self, tag: Optional[str]) -> bool:
return tag in ("weights", "model_weights")
def get_mode(self) -> str:
return self._mode
def get_allocator(self) -> Optional["GMSClientMemoryManager"]:
return self._weights_allocator
@contextmanager @contextmanager
def region(self, tag: str, enable_cpu_backup: bool): def region(self, tag: str, enable_cpu_backup: bool):
"""Mark allocation region with tag.""" """Mark allocation region with tag."""
if self._is_weights_tag(tag): if enable_cpu_backup:
if self._mode == "read": raise ValueError(
yield "SGLang with GMS does not support CPU backup for allocations."
return )
target_device = torch.device("cuda", self._device_index) if tag not in _TAG_LOCK_TYPES:
with gms_use_mem_pool("weights", target_device): logger.warning(
"[GMS] Ignoring unsupported torch_memory_saver region tag %r; "
"supported tags are %s",
tag,
list(GMS_TAGS),
)
yield yield
return return
if tag == "kv_cache": if (
target_device = torch.device("cuda", self._device_index) tag == "weights"
with gms_use_mem_pool("kv_cache", target_device): and self.allocators["weights"].granted_lock_type == GrantedLockType.RO
):
# Imported weights are already mapped and immutable in RO mode, so
# there is no allocator swap to install for this region.
yield yield
return return
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup): allocator = self.allocators[tag]
if allocator.granted_lock_type != GrantedLockType.RW:
mode = (
allocator.granted_lock_type.name
if allocator.granted_lock_type is not None
else "DISCONNECTED"
)
# The server would reject writes on a non-RW session too, but we
# fail before entering the allocation path so SGLang never starts a
# partial region with the wrong lock state.
raise RuntimeError(
f"SGLang with GMS requires {tag!r} to be RW for allocations; got {mode}"
)
with gms_use_mem_pool(tag, self._device):
yield yield
@contextmanager
def cuda_graph(
self,
cuda_graph,
pool,
stream,
capture_error_mode,
tag: str,
enable_cpu_backup: bool,
):
# The old hybrid path could delegate this to torch_memory_saver, but
# strict GMS mode has no compatible pauseable CUDA-graph allocator hook.
raise RuntimeError(
"SGLang with GMS does not support pauseable CUDA graphs. "
"torch_memory_saver only supports cuda_graph in hook_mode=preload, "
"and GMS does not use the LD_PRELOAD path."
)
def pause(self, tag: Optional[str] = None) -> None: def pause(self, tag: Optional[str] = None) -> None:
if self._disabled: for target_tag in _pause_resume_tags(tag):
return if self.allocators[target_tag].is_unmapped:
if tag is None or self._is_weights_tag(tag): continue
self._pause_weights() logger.info("[GMS] Unmapping %s", target_tag)
if tag is None or tag == "kv_cache": self.allocators[target_tag].unmap_all_vas()
self._pause_kv_cache() # abort() drops the current session after unmapping while keeping
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"): # the VA reservation alive for the next resume().
self._torch_impl.pause(tag=tag) self.allocators[target_tag].abort()
def resume(self, tag: Optional[str] = None) -> None: def resume(self, tag: Optional[str] = None) -> None:
if self._disabled: for target_tag in _pause_resume_tags(tag):
return if not self.allocators[target_tag].is_unmapped:
if tag is None or self._is_weights_tag(tag): continue
self._resume_weights()
if tag is None or tag == "kv_cache": logger.info("[GMS] Remapping %s", target_tag)
self._resume_kv_cache() self.allocators[target_tag].connect(_TAG_LOCK_TYPES[target_tag])
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"): if target_tag == "kv_cache":
self._torch_impl.resume(tag=tag) # KV cache resumes into a new RW layout epoch, so the handles
# must be re-created before the VA range is mapped again.
def _pause_weights(self) -> None: self.allocators[target_tag].reallocate_all_handles(tag=target_tag)
if self._weights_allocator is None: self.allocators[target_tag].remap_all_vas()
return
if self._weights_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping weights (VA-stable)")
self._weights_allocator.unmap_all_vas()
self._weights_allocator.abort()
def _resume_weights(self) -> None:
if self._weights_allocator is None:
return
if not self._weights_allocator.is_unmapped:
return
logger.info("[GMS] Remapping weights (VA-stable)")
self._weights_allocator.connect(RequestedLockType.RO)
self._weights_allocator.remap_all_vas()
def _pause_kv_cache(self) -> None:
if self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping KV cache")
self._kv_cache_allocator.unmap_all_vas()
self._kv_cache_allocator.abort()
def _resume_kv_cache(self) -> None:
if not self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Remapping KV cache")
self._kv_cache_allocator.connect(RequestedLockType.RW)
self._kv_cache_allocator.reallocate_all_handles(tag="kv_cache")
self._kv_cache_allocator.remap_all_vas()
def finalize_write_mode(self, model: torch.nn.Module) -> None: def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read.""" """Finalize write mode: register tensors, commit, and switch to read."""
if self._mode != "write": if self.allocators["weights"].granted_lock_type != GrantedLockType.RW:
# Read-only import mode never republishes weights.
return return
if self._weights_allocator is None:
raise RuntimeError("Allocator is None in WRITE mode")
from gpu_memory_service.integrations.common.utils import finalize_gms_write self.imported_weights_bytes = finalize_gms_write(
self.allocators["weights"], model
self._imported_weights_bytes = finalize_gms_write(
self._weights_allocator, model
) )
self._mode = "read"
def set_imported_weights_bytes(self, bytes_count: int) -> None:
self._imported_weights_bytes = bytes_count
def get_imported_weights_bytes(self) -> int:
return self._imported_weights_bytes
def disable(self) -> None:
self._disabled = True
def enable(self) -> None:
self._disabled = False
...@@ -16,11 +16,16 @@ from __future__ import annotations ...@@ -16,11 +16,16 @@ from __future__ import annotations
import logging import logging
import torch import torch
from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.locks import GrantedLockType
from gpu_memory_service.integrations.common import patch_empty_cache from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import ( from gpu_memory_service.integrations.common.utils import (
setup_meta_tensor_workaround, setup_meta_tensor_workaround,
strip_gms_model_loader_config, strip_gms_model_loader_config,
) )
from gpu_memory_service.integrations.sglang.memory_saver import (
get_gms_memory_saver_impl,
)
from gpu_memory_service.integrations.sglang.patches import ( from gpu_memory_service.integrations.sglang.patches import (
patch_model_runner, patch_model_runner,
patch_static_state_for_gms, patch_static_state_for_gms,
...@@ -66,10 +71,6 @@ class GMSModelLoader: ...@@ -66,10 +71,6 @@ class GMSModelLoader:
device_config, device_config,
) -> torch.nn.Module: ) -> torch.nn.Module:
"""Load or import model weights.""" """Load or import model weights."""
from gpu_memory_service.integrations.sglang.memory_saver import (
get_gms_memory_saver_impl,
)
impl = get_gms_memory_saver_impl() impl = get_gms_memory_saver_impl()
if impl is None: if impl is None:
raise RuntimeError( raise RuntimeError(
...@@ -77,12 +78,11 @@ class GMSModelLoader: ...@@ -77,12 +78,11 @@ class GMSModelLoader:
"Ensure torch_memory_saver patch was applied before model loading." "Ensure torch_memory_saver patch was applied before model loading."
) )
mode = impl.get_mode() mode = impl.allocators["weights"].granted_lock_type
logger.info("[GMS] Loading model in %s mode", mode.upper()) logger.info("[GMS] Loading model in %s mode", mode.name)
if mode == "read": if mode == GrantedLockType.RO:
return self._load_import_only(model_config, device_config, impl) return self._load_import_only(model_config, device_config, impl)
else:
return self._load_write_mode(model_config, device_config, impl) return self._load_write_mode(model_config, device_config, impl)
def _load_write_mode(self, model_config, device_config, impl) -> torch.nn.Module: def _load_write_mode(self, model_config, device_config, impl) -> torch.nn.Module:
...@@ -99,17 +99,13 @@ class GMSModelLoader: ...@@ -99,17 +99,13 @@ class GMSModelLoader:
def _load_import_only(self, model_config, device_config, impl) -> torch.nn.Module: def _load_import_only(self, model_config, device_config, impl) -> torch.nn.Module:
"""Import model weights from GMS metadata (READ mode).""" """Import model weights from GMS metadata (READ mode)."""
from gpu_memory_service.client.torch.module import materialize_module_from_gms allocator = impl.allocators["weights"]
allocator = impl.get_allocator()
if allocator is None:
raise RuntimeError("GMS allocator is None in READ mode")
device_index = torch.cuda.current_device() device_index = torch.cuda.current_device()
model = self._create_meta_model(model_config, device_config) model = self._create_meta_model(model_config, device_config)
materialize_module_from_gms(allocator, model, device_index=device_index) materialize_module_from_gms(allocator, model, device_index=device_index)
impl.set_imported_weights_bytes(allocator.total_bytes) impl.imported_weights_bytes = allocator.total_bytes
logger.info( logger.info(
"[GMS] READ mode: imported %.2f GiB from metadata", "[GMS] READ mode: imported %.2f GiB from metadata",
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""SGLang-specific patches for GPU Memory Service integration. """SGLang-specific patches for GPU Memory Service integration.
- patch_torch_memory_saver: Routes to GMS hybrid implementation - patch_torch_memory_saver: Routes weights and kv_cache to GMS
- patch_model_runner: Fixes memory accounting with pre-loaded weights - patch_model_runner: Fixes memory accounting with pre-loaded weights
- patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them) - patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them)
""" """
...@@ -15,7 +15,12 @@ import logging ...@@ -15,7 +15,12 @@ import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import Optional
import gpu_memory_service.integrations.sglang as gms_sglang
import torch import torch
from gpu_memory_service.integrations.sglang.memory_saver import (
GMSMemorySaverImpl,
get_gms_memory_saver_impl,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -57,25 +62,16 @@ def patch_torch_memory_saver() -> None: ...@@ -57,25 +62,16 @@ def patch_torch_memory_saver() -> None:
logger.info(f"[GMS] TorchMemorySaver initializing with hook_mode={hook_mode}") logger.info(f"[GMS] TorchMemorySaver initializing with hook_mode={hook_mode}")
if hook_mode is None or hook_mode == "gms": if hook_mode is None or hook_mode == "gms":
# Use our GPU Memory Service implementation # In GMS mode we install only the strict GMS implementation:
from gpu_memory_service.integrations.sglang.memory_saver import ( # weights + kv_cache go through GMS, generic unsupported tags stay
GMSMemorySaverImpl, # no-ops/warnings, and cuda_graph remains unsupported.
)
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
# Get device from torch.cuda.current_device() (already set by SGLang) # Get device from torch.cuda.current_device() (already set by SGLang)
device_index = torch.cuda.current_device() device_index = torch.cuda.current_device()
# Create underlying torch impl for non-GMS tags.
torch_impl = _TorchMemorySaverImpl(hook_mode="torch")
# Read lock mode set by setup_gms() (defaults to RW_OR_RO) # Read lock mode set by setup_gms() (defaults to RW_OR_RO)
from gpu_memory_service.integrations.sglang import _gms_lock_mode
gms_impl = GMSMemorySaverImpl( gms_impl = GMSMemorySaverImpl(
torch_impl=torch_impl,
device_index=device_index, device_index=device_index,
mode=_gms_lock_mode, mode=gms_sglang._gms_lock_mode,
) )
# Set _impl directly (accessible via gms_impl property) # Set _impl directly (accessible via gms_impl property)
...@@ -83,7 +79,7 @@ def patch_torch_memory_saver() -> None: ...@@ -83,7 +79,7 @@ def patch_torch_memory_saver() -> None:
logger.info( logger.info(
"[GMS] Using GMS mode (device=%d, mode=%s)", "[GMS] Using GMS mode (device=%d, mode=%s)",
device_index, device_index,
gms_impl.get_mode(), gms_impl.allocators["weights"].granted_lock_type.name,
) )
del self._impl_ctor_kwargs del self._impl_ctor_kwargs
else: else:
...@@ -111,8 +107,6 @@ def patch_torch_memory_saver() -> None: ...@@ -111,8 +107,6 @@ def patch_torch_memory_saver() -> None:
torch_memory_saver.configure_subprocess = patched_configure_subprocess torch_memory_saver.configure_subprocess = patched_configure_subprocess
# Add property to access GMS impl directly from the singleton # Add property to access GMS impl directly from the singleton
from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl
@property @property
def gms_impl(self) -> Optional[GMSMemorySaverImpl]: def gms_impl(self) -> Optional[GMSMemorySaverImpl]:
"""Get the GMS impl if installed, None otherwise.""" """Get the GMS impl if installed, None otherwise."""
...@@ -185,12 +179,8 @@ def patch_model_runner() -> None: ...@@ -185,12 +179,8 @@ def patch_model_runner() -> None:
weights are already resident. Newer SGLang versions changed this API, so weights are already resident. Newer SGLang versions changed this API, so
only rewrite the old total_gpu_memory parameter shape. only rewrite the old total_gpu_memory parameter shape.
""" """
from gpu_memory_service.integrations.sglang.memory_saver import (
get_gms_memory_saver_impl,
)
impl = get_gms_memory_saver_impl() impl = get_gms_memory_saver_impl()
if impl is not None and impl.get_imported_weights_bytes() > 0: if impl is not None and impl.imported_weights_bytes > 0:
total_memory_gib = torch.cuda.get_device_properties( total_memory_gib = torch.cuda.get_device_properties(
torch.cuda.current_device() torch.cuda.current_device()
).total_memory / (1 << 30) ).total_memory / (1 << 30)
......
...@@ -14,10 +14,12 @@ import logging ...@@ -14,10 +14,12 @@ import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager from gpu_memory_service.client.torch.allocator import (
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import materialize_module_from_gms from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.types import GrantedLockType from gpu_memory_service.common.locks import GrantedLockType
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common.utils import ( from gpu_memory_service.integrations.common.utils import (
finalize_gms_write, finalize_gms_write,
......
...@@ -14,8 +14,8 @@ from __future__ import annotations ...@@ -14,8 +14,8 @@ from __future__ import annotations
import logging import logging
from gpu_memory_service import get_gms_client_memory_manager from gpu_memory_service.client.torch.allocator import get_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType from gpu_memory_service.common.locks import GrantedLockType
from gpu_memory_service.integrations.vllm.utils import is_shadow_mode from gpu_memory_service.integrations.vllm.utils import is_shadow_mode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -18,16 +18,16 @@ from contextlib import nullcontext ...@@ -18,16 +18,16 @@ from contextlib import nullcontext
from typing import List, Optional from typing import List, Optional
import torch import torch
from gpu_memory_service import ( from gpu_memory_service.client.memory_manager import StaleMemoryLayoutError
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager, get_gms_client_memory_manager,
get_or_create_gms_client_memory_manager, get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
) )
from gpu_memory_service.client.memory_manager import StaleMemoryLayoutError from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common import patch_empty_cache from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import get_gms_lock_mode from gpu_memory_service.integrations.common.utils import GMS_TAGS, get_gms_lock_mode
from gpu_memory_service.integrations.vllm.model_loader import register_gms_loader from gpu_memory_service.integrations.vllm.model_loader import register_gms_loader
from gpu_memory_service.integrations.vllm.patches import ( from gpu_memory_service.integrations.vllm.patches import (
apply_shadow_mode_patches, apply_shadow_mode_patches,
...@@ -264,7 +264,7 @@ class GMSWorker(Worker): ...@@ -264,7 +264,7 @@ class GMSWorker(Worker):
self.model_runner.exit_shadow_init() self.model_runner.exit_shadow_init()
if tags is None: if tags is None:
tags = ["weights", "kv_cache"] tags = list(GMS_TAGS)
if "weights" in tags: if "weights" in tags:
weights_manager = get_gms_client_memory_manager("weights") weights_manager = get_gms_client_memory_manager("weights")
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service server components."""
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateSnapshot,
)
from gpu_memory_service.server.allocations import (
AllocationInfo,
AllocationNotFoundError,
GMSAllocationManager,
)
from gpu_memory_service.server.gms import GMS, MetadataEntry
from gpu_memory_service.server.rpc import GMSRPCServer
from gpu_memory_service.server.session import (
Connection,
GMSSessionManager,
InvalidTransition,
OperationNotAllowed,
)
__all__ = [
"GMSRPCServer",
"GMS",
"GMSSessionManager",
"GMSAllocationManager",
"AllocationInfo",
"AllocationNotFoundError",
"MetadataEntry",
"Connection",
"GrantedLockType",
"RequestedLockType",
"ServerState",
"StateSnapshot",
"InvalidTransition",
"OperationNotAllowed",
]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Optional, Set
from gpu_memory_service.common.locks import GrantedLockType
class ServerState(str, Enum):
    """Lock state of the server.

    Members also subclass ``str``, so each one compares equal to its own
    name (``ServerState.RW == "RW"``).
    """

    EMPTY = "EMPTY"
    RW = "RW"
    COMMITTED = "COMMITTED"
    RO = "RO"
class StateEvent(Enum):
    """Events that can be submitted to the lock state machine.

    Values are assigned by ``auto()``; only member identity is meaningful.
    """

    RW_CONNECT = auto()
    RW_COMMIT = auto()
    RW_ABORT = auto()
    RO_CONNECT = auto()
    RO_DISCONNECT = auto()
@dataclass(eq=False)
class Connection:
    """A client connection: its stream pair, granted lock mode and session id.

    ``eq=False`` stops the dataclass from generating ``__eq__``, so two
    connections are equal only when they are the same object. Hashing is
    based on ``session_id``.
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    mode: GrantedLockType
    session_id: str
    # default_factory gives every connection its own buffer.
    recv_buffer: bytearray = field(default_factory=bytearray)

    def __hash__(self) -> int:
        return hash(self.session_id)

    async def close(self) -> None:
        """Close the writer and wait for the close to complete.

        Any exception raised while waiting is swallowed: closing is
        best-effort and never propagates an error to the caller.
        """
        self.writer.close()
        try:
            await self.writer.wait_closed()
        except Exception:
            # Deliberately ignored; the connection is being discarded anyway.
            pass
class InvalidTransition(Exception):
    """Raised when an invalid state transition is attempted."""
@dataclass(frozen=True)
class Transition:
    """One immutable row of the state-transition table."""

    # States from which this row may fire.
    from_states: frozenset[ServerState]
    # Event that triggers this row.
    event: StateEvent
    # State this row declares as its outcome.
    to_state: Optional[ServerState]
    # Optional name of an extra guard that must hold; None means no guard.
    condition: Optional[str] = None
# Transition table: one row per (source states, event[, condition]).
TRANSITIONS: list[Transition] = [
    # A writer may connect when nothing is held, or on top of a commit.
    Transition(
        from_states=frozenset({ServerState.EMPTY, ServerState.COMMITTED}),
        event=StateEvent.RW_CONNECT,
        to_state=ServerState.RW,
    ),
    # The writer publishes its work.
    Transition(
        from_states=frozenset({ServerState.RW}),
        event=StateEvent.RW_COMMIT,
        to_state=ServerState.COMMITTED,
    ),
    # The writer gives up without committing.
    Transition(
        from_states=frozenset({ServerState.RW}),
        event=StateEvent.RW_ABORT,
        to_state=ServerState.EMPTY,
    ),
    # Readers may join once there is a commit, including alongside others.
    Transition(
        from_states=frozenset({ServerState.COMMITTED, ServerState.RO}),
        event=StateEvent.RO_CONNECT,
        to_state=ServerState.RO,
    ),
    # RO_DISCONNECT has two rows, told apart by their condition.
    Transition(
        from_states=frozenset({ServerState.RO}),
        event=StateEvent.RO_DISCONNECT,
        to_state=ServerState.RO,
        condition="has_remaining_readers",
    ),
    Transition(
        from_states=frozenset({ServerState.RO}),
        event=StateEvent.RO_DISCONNECT,
        to_state=ServerState.COMMITTED,
        condition="is_last_reader",
    ),
]
class GMSFSM:
    """Lock state machine over one RW holder, a set of RO readers and a commit flag.

    The visible ``state`` is derived from those three fields on every access
    rather than stored. ``transition`` checks each event against
    ``TRANSITIONS`` before it mutates anything.
    """

    def __init__(self):
        self._rw_conn: Optional[Connection] = None
        self._ro_conns: Set[Connection] = set()
        # Set by RW_COMMIT and cleared again by the next RW_CONNECT.
        self._committed = False

    @property
    def state(self) -> ServerState:
        """Derive the state; precedence is RW, then RO, then COMMITTED, then EMPTY."""
        if self._rw_conn is not None:
            return ServerState.RW
        if self._ro_conns:
            return ServerState.RO
        if self._committed:
            return ServerState.COMMITTED
        return ServerState.EMPTY

    @property
    def rw_conn(self) -> Optional[Connection]:
        """The connection holding RW, or None."""
        return self._rw_conn

    @property
    def ro_conns(self) -> Set[Connection]:
        """The internal reader set itself (not a copy)."""
        return self._ro_conns

    @property
    def ro_count(self) -> int:
        """Number of connected readers."""
        return len(self._ro_conns)

    @property
    def committed(self) -> bool:
        """Whether a commit is currently in place."""
        return self._committed

    def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
        """Evaluate a named transition guard for ``conn``.

        Raises:
            ValueError: if ``condition`` is not None and not a known name.
        """
        if condition is None:
            return True
        if condition == "has_remaining_readers":
            # Also true when conn is not a registered reader at all.
            return len(self._ro_conns) > 1 or conn not in self._ro_conns
        if condition == "is_last_reader":
            return len(self._ro_conns) == 1 and conn in self._ro_conns
        raise ValueError(f"Unknown condition: {condition}")

    def transition(self, event: StateEvent, conn: Connection) -> ServerState:
        """Validate ``event`` against the table, apply it, and return the new state.

        Raises:
            InvalidTransition: if no row in ``TRANSITIONS`` matches the
                current state, the event and the row's condition.
        """
        from_state = self.state
        # The first matching row only authorises the event; the field
        # updates below are keyed on the event itself.
        for transition in TRANSITIONS:
            if from_state not in transition.from_states:
                continue
            if transition.event != event:
                continue
            if not self._check_condition(transition.condition, conn):
                continue
            break
        else:
            raise InvalidTransition(
                f"No transition for {event.name} from state {from_state.name} "
                f"(session={conn.session_id})"
            )

        if event == StateEvent.RW_CONNECT:
            self._rw_conn = conn
            self._committed = False
        elif event == StateEvent.RW_COMMIT:
            self._committed = True
            self._rw_conn = None
        elif event == StateEvent.RW_ABORT:
            self._rw_conn = None
        elif event == StateEvent.RO_CONNECT:
            self._ro_conns.add(conn)
        elif event == StateEvent.RO_DISCONNECT:
            self._ro_conns.discard(conn)
        return self.state

    def can_acquire_rw(self) -> bool:
        """RW is available only when there is no writer and no reader."""
        return self._rw_conn is None and not self._ro_conns

    def can_acquire_ro(self, waiting_writers: int) -> bool:
        """RO needs a commit, no active writer, and no writers waiting."""
        return self._committed and self._rw_conn is None and waiting_writers == 0
...@@ -11,6 +11,7 @@ from collections import deque ...@@ -11,6 +11,7 @@ from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional from typing import Callable, Optional
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
AllocateRequest, AllocateRequest,
AllocateResponse, AllocateResponse,
...@@ -42,15 +43,10 @@ from gpu_memory_service.common.protocol.messages import ( ...@@ -42,15 +43,10 @@ from gpu_memory_service.common.protocol.messages import (
MetadataPutRequest, MetadataPutRequest,
MetadataPutResponse, MetadataPutResponse,
) )
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from .allocations import AllocationInfo, GMSAllocationManager from .allocations import AllocationInfo, GMSAllocationManager
from .session import Connection, GMSSessionManager from .fsm import Connection, ServerState, StateEvent
from .session import GMSSessionManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -23,8 +23,9 @@ from gpu_memory_service.common.protocol.wire import recv_message, send_message ...@@ -23,8 +23,9 @@ from gpu_memory_service.common.protocol.wire import recv_message, send_message
from gpu_memory_service.common.utils import fail from gpu_memory_service.common.utils import fail
from .allocations import AllocationNotFoundError from .allocations import AllocationNotFoundError
from .fsm import Connection, InvalidTransition
from .gms import GMS from .gms import GMS
from .session import Connection, InvalidTransition, OperationNotAllowed from .session import OperationNotAllowed
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Server-side connection, FSM, and waiter state.""" """Server-side lock acquisition and cleanup."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Optional, Set from typing import Optional
from gpu_memory_service.common.types import ( from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
RO_ALLOWED, from gpu_memory_service.common.protocol.messages import (
RW_ALLOWED, AllocateRequest,
RW_REQUIRED, CommitRequest,
GrantedLockType, ExportAllocationRequest,
RequestedLockType, FreeAllocationRequest,
ServerState, GetAllocationRequest,
StateEvent, GetAllocationStateRequest,
GetLockStateRequest,
GetStateHashRequest,
ListAllocationsRequest,
MetadataDeleteRequest,
MetadataGetRequest,
MetadataListRequest,
MetadataPutRequest,
) )
from .fsm import GMSFSM, Connection, ServerState, StateEvent
@dataclass(eq=False)
class Connection:
    """A client connection: its stream pair, granted lock mode and session id.

    ``eq=False`` stops the dataclass from generating ``__eq__``, so two
    connections are equal only when they are the same object. Hashing is
    based on ``session_id``.
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    mode: GrantedLockType
    session_id: str
    # default_factory gives every connection its own buffer.
    recv_buffer: bytearray = field(default_factory=bytearray)

    def __hash__(self) -> int:
        return hash(self.session_id)

    async def close(self) -> None:
        """Close the writer and wait for the close to complete.

        Any exception raised while waiting is swallowed: closing is
        best-effort and never propagates an error to the caller.
        """
        self.writer.close()
        try:
            await self.writer.wait_closed()
        except Exception:
            # Deliberately ignored; the connection is being discarded anyway.
            pass
class InvalidTransition(Exception):
    """Error for an event that has no valid transition from the current state.

    ``GMSLocalFSM.transition`` raises it when no row of the transition table
    matches the event in the machine's current state.
    """
class OperationNotAllowed(Exception): class OperationNotAllowed(Exception):
"""Raised when an operation is not allowed in the current state/mode.""" pass
@dataclass(frozen=True)
class Transition:
    """One immutable row of the lock state-machine transition table.

    A row matches when the current state is in ``from_states``, the incoming
    event equals ``event`` and the optional ``condition`` guard holds.
    """

    # States from which this row may fire.
    from_states: frozenset[ServerState]
    # Event that selects this row.
    event: StateEvent
    # Declared target state of the row.
    to_state: Optional[ServerState]
    # Name of a guard resolved by ``GMSLocalFSM._check_condition``;
    # ``None`` means the row is unconditional.
    condition: Optional[str] = None
# Transition table scanned, in order, by ``GMSLocalFSM._find_transition``.
# Each row below is (source states, event, target state, guard name); the
# comprehension turns the rows into ``Transition`` records without changing
# their order.
TRANSITIONS: list[Transition] = [
    Transition(
        from_states=frozenset(sources),
        event=event,
        to_state=target,
        condition=guard,
    )
    for sources, event, target, guard in (
        # A writer connects from EMPTY or COMMITTED.
        (
            {ServerState.EMPTY, ServerState.COMMITTED},
            StateEvent.RW_CONNECT,
            ServerState.RW,
            None,
        ),
        # The writer commits its work.
        ({ServerState.RW}, StateEvent.RW_COMMIT, ServerState.COMMITTED, None),
        # The writer aborts.
        ({ServerState.RW}, StateEvent.RW_ABORT, ServerState.EMPTY, None),
        # A reader connects from COMMITTED or RO.
        (
            {ServerState.COMMITTED, ServerState.RO},
            StateEvent.RO_CONNECT,
            ServerState.RO,
            None,
        ),
        # A reader leaves while the "has_remaining_readers" guard holds.
        (
            {ServerState.RO},
            StateEvent.RO_DISCONNECT,
            ServerState.RO,
            "has_remaining_readers",
        ),
        # A reader leaves while the "is_last_reader" guard holds.
        (
            {ServerState.RO},
            StateEvent.RO_DISCONNECT,
            ServerState.COMMITTED,
            "is_last_reader",
        ),
    )
]
class GMSLocalFSM:
"""Explicit connection/lock state machine."""
def __init__(self):
self._rw_conn: Optional[Connection] = None
self._ro_conns: Set[Connection] = set()
self._committed = False
@property
def state(self) -> ServerState:
    """Derive the current state from the tracked connections and flag.

    Precedence: a registered writer gives ``RW``; otherwise any registered
    reader gives ``RO``; otherwise the committed flag selects ``COMMITTED``
    or ``EMPTY``.
    """
    if self._rw_conn is not None:
        return ServerState.RW
    if self._ro_conns:
        return ServerState.RO
    return ServerState.COMMITTED if self._committed else ServerState.EMPTY
@property
def rw_conn(self) -> Optional[Connection]:
return self._rw_conn
@property
def ro_conns(self) -> Set[Connection]:
return self._ro_conns
@property
def ro_count(self) -> int:
return len(self._ro_conns)
@property
def committed(self) -> bool:
return self._committed
def _has_remaining_readers(self, conn: Connection) -> bool:
return len(self._ro_conns) > 1 or conn not in self._ro_conns
def _is_last_reader(self, conn: Connection) -> bool:
return len(self._ro_conns) == 1 and conn in self._ro_conns
def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
if condition is None:
return True
if condition == "has_remaining_readers":
return self._has_remaining_readers(conn)
if condition == "is_last_reader":
return self._is_last_reader(conn)
raise ValueError(f"Unknown condition: {condition}")
def _find_transition(
    self,
    from_state: ServerState,
    event: StateEvent,
    conn: Connection,
) -> Optional[Transition]:
    """Return the first row of ``TRANSITIONS`` that applies, or ``None``.

    A row applies when ``from_state`` is one of its source states, its
    event equals ``event`` and its guard (if any) holds for ``conn``. The
    guard is evaluated only after the state and the event have matched.
    """
    for candidate in TRANSITIONS:
        if (
            from_state in candidate.from_states
            and candidate.event == event
            and self._check_condition(candidate.condition, conn)
        ):
            return candidate
    return None
def _apply_event(self, event: StateEvent, conn: Connection) -> None:
    """Update the tracked connections and committed flag for ``event``.

    No validation happens here; ``transition`` checks the table before
    calling this. An event not listed below leaves the machine unchanged.
    """
    if event == StateEvent.RW_CONNECT:
        # A new writer takes over and the committed flag is cleared.
        self._rw_conn = conn
        self._committed = False
        return
    if event == StateEvent.RW_COMMIT:
        # Mark the state committed, then drop the writer.
        self._committed = True
        self._rw_conn = None
        return
    if event == StateEvent.RW_ABORT:
        # Drop the writer; the committed flag is left as it is.
        self._rw_conn = None
        return
    if event == StateEvent.RO_CONNECT:
        self._ro_conns.add(conn)
        return
    if event == StateEvent.RO_DISCONNECT:
        # discard() tolerates a connection that was never registered.
        self._ro_conns.discard(conn)
def transition(self, event: StateEvent, conn: Connection) -> ServerState:
    """Validate ``event`` against the table, apply it, and return the new state.

    The matched row only authorises the event. The returned state is
    recomputed from the machine's fields after ``_apply_event`` runs; the
    row's ``to_state`` is not read here.

    Raises:
        InvalidTransition: If no table row matches ``event`` in the current
            state for ``conn``.
    """
    current = self.state
    matched = self._find_transition(current, event, conn)
    if matched is None:
        raise InvalidTransition(
            f"No transition for {event.name} from state {current.name} "
            f"(session={conn.session_id})"
        )
    self._apply_event(event, conn)
    return self.state
def check_operation(self, msg_type: type, conn: Connection) -> None: RW_REQUIRED: frozenset[type] = frozenset(
if conn.mode == GrantedLockType.RW and msg_type not in RW_ALLOWED: {
raise OperationNotAllowed( AllocateRequest,
f"{msg_type.__name__} not allowed for RW session in state {self.state.name}" FreeAllocationRequest,
) MetadataPutRequest,
if conn.mode == GrantedLockType.RO and msg_type not in RO_ALLOWED: MetadataDeleteRequest,
raise OperationNotAllowed( CommitRequest,
f"{msg_type.__name__} not allowed for RO session in state {self.state.name}" }
) )
if msg_type in RW_REQUIRED and conn.mode != GrantedLockType.RW:
raise OperationNotAllowed(
f"{msg_type.__name__} requires RW session, got {conn.mode.value}"
)
def can_acquire_rw(self) -> bool: RO_ALLOWED: frozenset[type] = frozenset(
return self._rw_conn is None and not self._ro_conns {
ExportAllocationRequest,
GetAllocationRequest,
ListAllocationsRequest,
MetadataGetRequest,
MetadataListRequest,
GetLockStateRequest,
GetAllocationStateRequest,
GetStateHashRequest,
}
)
def can_acquire_ro(self, waiting_writers: int) -> bool: RW_ALLOWED: frozenset[type] = RW_REQUIRED | RO_ALLOWED
return self._committed and self._rw_conn is None and waiting_writers == 0
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -215,7 +73,7 @@ class GMSSessionManager: ...@@ -215,7 +73,7 @@ class GMSSessionManager:
"""Owns lock transitions, waiter coordination, and cleanup.""" """Owns lock transitions, waiter coordination, and cleanup."""
def __init__(self): def __init__(self):
self._locking = GMSLocalFSM() self._locking = GMSFSM()
self._waiting_writers = 0 self._waiting_writers = 0
self._reserved_rw_session_id: Optional[str] = None self._reserved_rw_session_id: Optional[str] = None
self._condition = asyncio.Condition() self._condition = asyncio.Condition()
...@@ -336,7 +194,18 @@ class GMSSessionManager: ...@@ -336,7 +194,18 @@ class GMSSessionManager:
self._locking.transition(StateEvent.RW_COMMIT, conn) self._locking.transition(StateEvent.RW_COMMIT, conn)
def check_operation(self, msg_type: type, conn: Connection) -> None: def check_operation(self, msg_type: type, conn: Connection) -> None:
self._locking.check_operation(msg_type, conn) if conn.mode == GrantedLockType.RW and msg_type not in RW_ALLOWED:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed for RW session in state {self.state.name}"
)
if conn.mode == GrantedLockType.RO and msg_type not in RO_ALLOWED:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed for RO session in state {self.state.name}"
)
if msg_type in RW_REQUIRED and conn.mode != GrantedLockType.RW:
raise OperationNotAllowed(
f"{msg_type.__name__} requires RW session, got {conn.mode.value}"
)
def begin_cleanup(self, conn: Optional[Connection]) -> StateEvent | None: def begin_cleanup(self, conn: Optional[Connection]) -> StateEvent | None:
if conn is None: if conn is None:
......
...@@ -246,6 +246,7 @@ markers = [ ...@@ -246,6 +246,7 @@ markers = [
"stress: marks tests as stress tests", "stress: marks tests as stress tests",
"performance: marks tests as performance tests", "performance: marks tests as performance tests",
"benchmark: marks tests as benchmark tests", "benchmark: marks tests as benchmark tests",
"none: marks tests that do not require a framework-specific runtime",
"vllm: marks tests as requiring vllm", "vllm: marks tests as requiring vllm",
"trtllm: marks tests as requiring trtllm", "trtllm: marks tests as requiring trtllm",
"sglang: marks tests as requiring sglang", "sglang: marks tests as requiring sglang",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
pytest.importorskip("gpu_memory_service", reason="gpu_memory_service is required")
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
"""Tests for the flock-based failover lock. """Tests for the flock-based failover lock.
No GPU required — these are pure Python/OS tests exercising flock These are pure Python/OS tests exercising flock semantics across asyncio
semantics across asyncio tasks and child processes. tasks and child processes, so they stay on the generic cpu-style pre-merge
lane instead of the dedicated GPU job.
""" """
import asyncio import asyncio
...@@ -19,6 +20,7 @@ from gpu_memory_service.failover_lock.flock import FlockFailoverLock ...@@ -19,6 +20,7 @@ from gpu_memory_service.failover_lock.flock import FlockFailoverLock
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0, pytest.mark.gpu_0,
] ]
......
...@@ -9,12 +9,13 @@ from gpu_memory_service.client.memory_manager import ( ...@@ -9,12 +9,13 @@ from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager, GMSClientMemoryManager,
LocalMapping, LocalMapping,
) )
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.unit,
pytest.mark.gpu_0, pytest.mark.none,
pytest.mark.gpu_1,
] ]
......
...@@ -6,15 +6,16 @@ from __future__ import annotations ...@@ -6,15 +6,16 @@ from __future__ import annotations
import pytest import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
CommitResponse, CommitResponse,
HandshakeResponse, HandshakeResponse,
) )
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0, pytest.mark.gpu_0,
] ]
......
...@@ -15,6 +15,7 @@ from gpu_memory_service.common.protocol.messages import ( ...@@ -15,6 +15,7 @@ from gpu_memory_service.common.protocol.messages import (
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0, pytest.mark.gpu_0,
] ]
......
...@@ -10,7 +10,12 @@ import pytest ...@@ -10,7 +10,12 @@ import pytest
from tests.gms.harness.gms import GMSServerProcess from tests.gms.harness.gms import GMSServerProcess
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
pytestmark = [pytest.mark.pre_merge, pytest.mark.unit, pytest.mark.gpu_0] pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_1,
]
@pytest.fixture @pytest.fixture
......
...@@ -24,19 +24,16 @@ from gpu_memory_service.client.memory_manager import ( ...@@ -24,19 +24,16 @@ from gpu_memory_service.client.memory_manager import (
from gpu_memory_service.client.rpc import _GMSRPCTransport from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common import cuda_utils from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest, GetEventHistoryRequest,
GetEventHistoryResponse, GetEventHistoryResponse,
GetRuntimeStateRequest, GetRuntimeStateRequest,
GetRuntimeStateResponse, GetRuntimeStateResponse,
) )
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
)
from gpu_memory_service.server import allocations as server_allocations from gpu_memory_service.server import allocations as server_allocations
from gpu_memory_service.server.allocations import GMSAllocationManager from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.fsm import ServerState
from gpu_memory_service.server.rpc import GMSRPCServer from gpu_memory_service.server.rpc import GMSRPCServer
from tests.gms.harness.gms import ServerThread from tests.gms.harness.gms import ServerThread
...@@ -44,7 +41,8 @@ from tests.gms.harness.gms import ServerThread ...@@ -44,7 +41,8 @@ from tests.gms.harness.gms import ServerThread
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.unit,
pytest.mark.gpu_0, pytest.mark.none,
pytest.mark.gpu_1,
] ]
......
...@@ -15,6 +15,7 @@ from dataclasses import dataclass ...@@ -15,6 +15,7 @@ from dataclasses import dataclass
import pytest import pytest
from gpu_memory_service.common import cuda_utils from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
CommitRequest, CommitRequest,
CommitResponse, CommitResponse,
...@@ -24,13 +25,8 @@ from gpu_memory_service.common.protocol.messages import ( ...@@ -24,13 +25,8 @@ from gpu_memory_service.common.protocol.messages import (
GetRuntimeStateRequest, GetRuntimeStateRequest,
HandshakeRequest, HandshakeRequest,
) )
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from gpu_memory_service.server.allocations import GMSAllocationManager from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.fsm import ServerState, StateEvent
from gpu_memory_service.server.gms import GMS from gpu_memory_service.server.gms import GMS
from gpu_memory_service.server.rpc import GMSRPCServer, _is_connection_alive from gpu_memory_service.server.rpc import GMSRPCServer, _is_connection_alive
from gpu_memory_service.server.session import ( from gpu_memory_service.server.session import (
...@@ -46,7 +42,8 @@ from cuda.bindings import driver as cuda # noqa: E402 ...@@ -46,7 +42,8 @@ from cuda.bindings import driver as cuda # noqa: E402
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.unit,
pytest.mark.gpu_0, pytest.mark.none,
pytest.mark.gpu_1,
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment