Unverified Commit 39a6a240 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

refactor: simplify GPU Memory Service integrations and module boundaries (#7875)

parent 02666f04
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Hybrid torch_memory_saver implementation for GPU Memory Service.
"""torch_memory_saver implementation for GPU Memory Service.
This module uses:
1. GPU Memory Service for "weights" (shared RO/RW publish flow)
2. GPU Memory Service for "kv_cache" (RW-only failover flow)
3. torch_memory_saver for any remaining tags
SGLang with GMS owns exactly two memory classes:
1. "weights" via the shared RO/RW publish flow
2. "kv_cache" via the RW failover flow
Unsupported release/resume tags stay no-ops with a warning so the generic
SGLang memory-control API can still pass broader tag sets without reintroducing
the old torch-memory-saver fallback. `cuda_graph` is a hard error because the
pauseable CUDA-graph path depends on the LD_PRELOAD torch allocator hooks that
GMS intentionally does not use.
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.client.torch.allocator import (
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
from gpu_memory_service.integrations.common.utils import GMS_TAGS, finalize_gms_write
logger = logging.getLogger(__name__)
# Published weights must come back RO, while KV cache always resumes in a fresh
# RW epoch so the restored engine can rebuild mutable cache state.
_TAG_LOCK_TYPES = {"weights": RequestedLockType.RO, "kv_cache": RequestedLockType.RW}
def _pause_resume_tags(tag: Optional[str]) -> tuple[str, ...]:
    """Resolve a pause/resume ``tag`` argument to the GMS tags to act on.

    Args:
        tag: Tag handed to ``pause``/``resume``. ``None`` selects every tag in
            ``GMS_TAGS``.

    Returns:
        ``GMS_TAGS`` when ``tag`` is ``None``; a one-element tuple when ``tag``
        has an entry in ``_TAG_LOCK_TYPES``; otherwise an empty tuple, after a
        warning has been logged, so callers iterate over nothing.
    """
    if tag is None:
        return GMS_TAGS
    if tag not in _TAG_LOCK_TYPES:
        # Unsupported tags are reported and then treated as a no-op.
        logger.warning(
            "[GMS] Ignoring unsupported torch_memory_saver tag %r; supported tags are %s",
            tag,
            list(GMS_TAGS),
        )
        return ()
    return (tag,)
def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
"""Get the GMS memory saver impl from the torch_memory_saver singleton."""
......@@ -39,170 +60,126 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
class GMSMemorySaverImpl:
"""Hybrid implementation: GMS for weights and KV cache."""
"""SGLang memory saver implementation backed only by GMS."""
def __init__(
self,
torch_impl: "_TorchMemorySaverImpl",
device_index: int,
mode=None,
):
self._torch_impl = torch_impl
self._device_index = device_index
self._requested_mode = mode
self._disabled = False
self._imported_weights_bytes: int = 0
self._weights_allocator: Optional["GMSClientMemoryManager"]
self._kv_cache_allocator: "GMSClientMemoryManager"
self._mode: str
(
self._weights_allocator,
self._kv_cache_allocator,
self._mode,
) = self._init_allocators()
logger.info(
"[GMS] Initialized weights=%s mode, kv_cache=RW (device=%d)",
self._mode.upper(),
self._device = torch.device("cuda", device_index)
self.imported_weights_bytes = 0
requested_mode = mode or RequestedLockType.RW_OR_RO
self.allocators = {
tag: get_or_create_gms_client_memory_manager(
get_socket_path(device_index, tag),
device_index,
# weights follow the configured publish/import mode; kv_cache is
# always mutable and therefore always needs an RW session.
mode=requested_mode if tag == "weights" else RequestedLockType.RW,
tag=tag,
)
for tag in GMS_TAGS
}
def _init_allocators(
self,
) -> tuple[Optional["GMSClientMemoryManager"], "GMSClientMemoryManager", str,]:
"""Create allocator with mode from config (default: RW_OR_RO)."""
mode = self._requested_mode or RequestedLockType.RW_OR_RO
weights_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "weights"),
self._device_index,
mode=mode,
tag="weights",
)
kv_cache_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "kv_cache"),
self._device_index,
mode=RequestedLockType.RW,
tag="kv_cache",
)
granted_mode = weights_allocator.granted_lock_type
if granted_mode == GrantedLockType.RW:
actual_mode = "write"
else:
actual_mode = "read"
logger.info(
"[GMS] Initialized in AUTO mode, granted=%s (device=%d)",
actual_mode.upper(),
self._device_index,
"[GMS] Initialized weights: requested=%s granted=%s (device=%d)",
requested_mode.name,
self.allocators["weights"].granted_lock_type.name,
device_index,
)
return weights_allocator, kv_cache_allocator, actual_mode
def _is_weights_tag(self, tag: Optional[str]) -> bool:
return tag in ("weights", "model_weights")
def get_mode(self) -> str:
return self._mode
def get_allocator(self) -> Optional["GMSClientMemoryManager"]:
return self._weights_allocator
@contextmanager
def region(self, tag: str, enable_cpu_backup: bool):
"""Mark allocation region with tag."""
if self._is_weights_tag(tag):
if self._mode == "read":
yield
return
if enable_cpu_backup:
raise ValueError(
"SGLang with GMS does not support CPU backup for allocations."
)
target_device = torch.device("cuda", self._device_index)
with gms_use_mem_pool("weights", target_device):
if tag not in _TAG_LOCK_TYPES:
logger.warning(
"[GMS] Ignoring unsupported torch_memory_saver region tag %r; "
"supported tags are %s",
tag,
list(GMS_TAGS),
)
yield
return
if tag == "kv_cache":
target_device = torch.device("cuda", self._device_index)
with gms_use_mem_pool("kv_cache", target_device):
if (
tag == "weights"
and self.allocators["weights"].granted_lock_type == GrantedLockType.RO
):
# Imported weights are already mapped and immutable in RO mode, so
# there is no allocator swap to install for this region.
yield
return
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
allocator = self.allocators[tag]
if allocator.granted_lock_type != GrantedLockType.RW:
mode = (
allocator.granted_lock_type.name
if allocator.granted_lock_type is not None
else "DISCONNECTED"
)
# The server would reject writes on a non-RW session too, but we
# fail before entering the allocation path so SGLang never starts a
# partial region with the wrong lock state.
raise RuntimeError(
f"SGLang with GMS requires {tag!r} to be RW for allocations; got {mode}"
)
with gms_use_mem_pool(tag, self._device):
yield
@contextmanager
def cuda_graph(
    self,
    cuda_graph,
    pool,
    stream,
    capture_error_mode,
    tag: str,
    enable_cpu_backup: bool,
):
    """Reject pauseable CUDA-graph capture; none of the arguments are used.

    Raises:
        RuntimeError: Always.
    """
    # The old hybrid path could delegate this to torch_memory_saver, but
    # strict GMS mode has no compatible pauseable CUDA-graph allocator hook.
    reason = (
        "SGLang with GMS does not support pauseable CUDA graphs. "
        "torch_memory_saver only supports cuda_graph in hook_mode=preload, "
        "and GMS does not use the LD_PRELOAD path."
    )
    raise RuntimeError(reason)
def pause(self, tag: Optional[str] = None) -> None:
if self._disabled:
return
if tag is None or self._is_weights_tag(tag):
self._pause_weights()
if tag is None or tag == "kv_cache":
self._pause_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.pause(tag=tag)
for target_tag in _pause_resume_tags(tag):
if self.allocators[target_tag].is_unmapped:
continue
logger.info("[GMS] Unmapping %s", target_tag)
self.allocators[target_tag].unmap_all_vas()
# abort() drops the current session after unmapping while keeping
# the VA reservation alive for the next resume().
self.allocators[target_tag].abort()
def resume(self, tag: Optional[str] = None) -> None:
if self._disabled:
return
if tag is None or self._is_weights_tag(tag):
self._resume_weights()
if tag is None or tag == "kv_cache":
self._resume_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.resume(tag=tag)
def _pause_weights(self) -> None:
if self._weights_allocator is None:
return
if self._weights_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping weights (VA-stable)")
self._weights_allocator.unmap_all_vas()
self._weights_allocator.abort()
def _resume_weights(self) -> None:
if self._weights_allocator is None:
return
if not self._weights_allocator.is_unmapped:
return
logger.info("[GMS] Remapping weights (VA-stable)")
self._weights_allocator.connect(RequestedLockType.RO)
self._weights_allocator.remap_all_vas()
def _pause_kv_cache(self) -> None:
if self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping KV cache")
self._kv_cache_allocator.unmap_all_vas()
self._kv_cache_allocator.abort()
def _resume_kv_cache(self) -> None:
if not self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Remapping KV cache")
self._kv_cache_allocator.connect(RequestedLockType.RW)
self._kv_cache_allocator.reallocate_all_handles(tag="kv_cache")
self._kv_cache_allocator.remap_all_vas()
for target_tag in _pause_resume_tags(tag):
if not self.allocators[target_tag].is_unmapped:
continue
logger.info("[GMS] Remapping %s", target_tag)
self.allocators[target_tag].connect(_TAG_LOCK_TYPES[target_tag])
if target_tag == "kv_cache":
# KV cache resumes into a new RW layout epoch, so the handles
# must be re-created before the VA range is mapped again.
self.allocators[target_tag].reallocate_all_handles(tag=target_tag)
self.allocators[target_tag].remap_all_vas()
def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read."""
if self._mode != "write":
if self.allocators["weights"].granted_lock_type != GrantedLockType.RW:
# Read-only import mode never republishes weights.
return
if self._weights_allocator is None:
raise RuntimeError("Allocator is None in WRITE mode")
from gpu_memory_service.integrations.common.utils import finalize_gms_write
self._imported_weights_bytes = finalize_gms_write(
self._weights_allocator, model
self.imported_weights_bytes = finalize_gms_write(
self.allocators["weights"], model
)
self._mode = "read"
def set_imported_weights_bytes(self, bytes_count: int) -> None:
self._imported_weights_bytes = bytes_count
def get_imported_weights_bytes(self) -> int:
return self._imported_weights_bytes
def disable(self) -> None:
self._disabled = True
def enable(self) -> None:
self._disabled = False
......@@ -16,11 +16,16 @@ from __future__ import annotations
import logging
import torch
from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.locks import GrantedLockType
from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import (
setup_meta_tensor_workaround,
strip_gms_model_loader_config,
)
from gpu_memory_service.integrations.sglang.memory_saver import (
get_gms_memory_saver_impl,
)
from gpu_memory_service.integrations.sglang.patches import (
patch_model_runner,
patch_static_state_for_gms,
......@@ -66,10 +71,6 @@ class GMSModelLoader:
device_config,
) -> torch.nn.Module:
"""Load or import model weights."""
from gpu_memory_service.integrations.sglang.memory_saver import (
get_gms_memory_saver_impl,
)
impl = get_gms_memory_saver_impl()
if impl is None:
raise RuntimeError(
......@@ -77,12 +78,11 @@ class GMSModelLoader:
"Ensure torch_memory_saver patch was applied before model loading."
)
mode = impl.get_mode()
logger.info("[GMS] Loading model in %s mode", mode.upper())
mode = impl.allocators["weights"].granted_lock_type
logger.info("[GMS] Loading model in %s mode", mode.name)
if mode == "read":
if mode == GrantedLockType.RO:
return self._load_import_only(model_config, device_config, impl)
else:
return self._load_write_mode(model_config, device_config, impl)
def _load_write_mode(self, model_config, device_config, impl) -> torch.nn.Module:
......@@ -99,17 +99,13 @@ class GMSModelLoader:
def _load_import_only(self, model_config, device_config, impl) -> torch.nn.Module:
"""Import model weights from GMS metadata (READ mode)."""
from gpu_memory_service.client.torch.module import materialize_module_from_gms
allocator = impl.get_allocator()
if allocator is None:
raise RuntimeError("GMS allocator is None in READ mode")
allocator = impl.allocators["weights"]
device_index = torch.cuda.current_device()
model = self._create_meta_model(model_config, device_config)
materialize_module_from_gms(allocator, model, device_index=device_index)
impl.set_imported_weights_bytes(allocator.total_bytes)
impl.imported_weights_bytes = allocator.total_bytes
logger.info(
"[GMS] READ mode: imported %.2f GiB from metadata",
......
......@@ -3,7 +3,7 @@
"""SGLang-specific patches for GPU Memory Service integration.
- patch_torch_memory_saver: Routes to GMS hybrid implementation
- patch_torch_memory_saver: Routes weights and kv_cache to GMS
- patch_model_runner: Fixes memory accounting with pre-loaded weights
- patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them)
"""
......@@ -15,7 +15,12 @@ import logging
from contextlib import contextmanager
from typing import Optional
import gpu_memory_service.integrations.sglang as gms_sglang
import torch
from gpu_memory_service.integrations.sglang.memory_saver import (
GMSMemorySaverImpl,
get_gms_memory_saver_impl,
)
logger = logging.getLogger(__name__)
......@@ -57,25 +62,16 @@ def patch_torch_memory_saver() -> None:
logger.info(f"[GMS] TorchMemorySaver initializing with hook_mode={hook_mode}")
if hook_mode is None or hook_mode == "gms":
# Use our GPU Memory Service implementation
from gpu_memory_service.integrations.sglang.memory_saver import (
GMSMemorySaverImpl,
)
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
# In GMS mode we install only the strict GMS implementation:
# weights + kv_cache go through GMS, generic unsupported tags stay
# no-ops/warnings, and cuda_graph remains unsupported.
# Get device from torch.cuda.current_device() (already set by SGLang)
device_index = torch.cuda.current_device()
# Create underlying torch impl for non-GMS tags.
torch_impl = _TorchMemorySaverImpl(hook_mode="torch")
# Read lock mode set by setup_gms() (defaults to RW_OR_RO)
from gpu_memory_service.integrations.sglang import _gms_lock_mode
gms_impl = GMSMemorySaverImpl(
torch_impl=torch_impl,
device_index=device_index,
mode=_gms_lock_mode,
mode=gms_sglang._gms_lock_mode,
)
# Set _impl directly (accessible via gms_impl property)
......@@ -83,7 +79,7 @@ def patch_torch_memory_saver() -> None:
logger.info(
"[GMS] Using GMS mode (device=%d, mode=%s)",
device_index,
gms_impl.get_mode(),
gms_impl.allocators["weights"].granted_lock_type.name,
)
del self._impl_ctor_kwargs
else:
......@@ -111,8 +107,6 @@ def patch_torch_memory_saver() -> None:
torch_memory_saver.configure_subprocess = patched_configure_subprocess
# Add property to access GMS impl directly from the singleton
from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl
@property
def gms_impl(self) -> Optional[GMSMemorySaverImpl]:
"""Get the GMS impl if installed, None otherwise."""
......@@ -185,12 +179,8 @@ def patch_model_runner() -> None:
weights are already resident. Newer SGLang versions changed this API, so
only rewrite the old total_gpu_memory parameter shape.
"""
from gpu_memory_service.integrations.sglang.memory_saver import (
get_gms_memory_saver_impl,
)
impl = get_gms_memory_saver_impl()
if impl is not None and impl.get_imported_weights_bytes() > 0:
if impl is not None and impl.imported_weights_bytes > 0:
total_memory_gib = torch.cuda.get_device_properties(
torch.cuda.current_device()
).total_memory / (1 << 30)
......
......@@ -14,10 +14,12 @@ import logging
from typing import TYPE_CHECKING
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.client.torch.allocator import (
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.locks import GrantedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common.utils import (
finalize_gms_write,
......
......@@ -14,8 +14,8 @@ from __future__ import annotations
import logging
from gpu_memory_service import get_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.client.torch.allocator import get_gms_client_memory_manager
from gpu_memory_service.common.locks import GrantedLockType
from gpu_memory_service.integrations.vllm.utils import is_shadow_mode
logger = logging.getLogger(__name__)
......
......@@ -18,16 +18,16 @@ from contextlib import nullcontext
from typing import List, Optional
import torch
from gpu_memory_service import (
from gpu_memory_service.client.memory_manager import StaleMemoryLayoutError
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.memory_manager import StaleMemoryLayoutError
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import get_gms_lock_mode
from gpu_memory_service.integrations.common.utils import GMS_TAGS, get_gms_lock_mode
from gpu_memory_service.integrations.vllm.model_loader import register_gms_loader
from gpu_memory_service.integrations.vllm.patches import (
apply_shadow_mode_patches,
......@@ -264,7 +264,7 @@ class GMSWorker(Worker):
self.model_runner.exit_shadow_init()
if tags is None:
tags = ["weights", "kv_cache"]
tags = list(GMS_TAGS)
if "weights" in tags:
weights_manager = get_gms_client_memory_manager("weights")
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service server components."""
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateSnapshot,
)
from gpu_memory_service.server.allocations import (
AllocationInfo,
AllocationNotFoundError,
GMSAllocationManager,
)
from gpu_memory_service.server.gms import GMS, MetadataEntry
from gpu_memory_service.server.rpc import GMSRPCServer
from gpu_memory_service.server.session import (
Connection,
GMSSessionManager,
InvalidTransition,
OperationNotAllowed,
)
__all__ = [
"GMSRPCServer",
"GMS",
"GMSSessionManager",
"GMSAllocationManager",
"AllocationInfo",
"AllocationNotFoundError",
"MetadataEntry",
"Connection",
"GrantedLockType",
"RequestedLockType",
"ServerState",
"StateSnapshot",
"InvalidTransition",
"OperationNotAllowed",
]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Optional, Set
from gpu_memory_service.common.locks import GrantedLockType
class ServerState(str, Enum):
    """Lock state reported by the server-side state machine.

    Each member is also a ``str`` whose value equals its name.
    """

    # No writer, no readers, and nothing committed.
    EMPTY = "EMPTY"
    # A writer connection is currently attached.
    RW = "RW"
    # Committed, with neither a writer nor any reader attached.
    COMMITTED = "COMMITTED"
    # No writer attached and at least one reader attached.
    RO = "RO"
class StateEvent(Enum):
    """Events that drive the server-side lock state machine."""

    # A writer connection attaches.
    RW_CONNECT = auto()
    # The attached writer commits.
    RW_COMMIT = auto()
    # The attached writer goes away without committing.
    RW_ABORT = auto()
    # A reader connection attaches.
    RO_CONNECT = auto()
    # A reader connection detaches.
    RO_DISCONNECT = auto()
@dataclass(eq=False)
class Connection:
    """One client connection: its streams, granted lock mode, and session id.

    ``eq=False`` keeps equality as object identity, while hashing is derived
    from ``session_id``.
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    mode: GrantedLockType
    session_id: str
    # Fresh, per-instance buffer (never shared between connections).
    recv_buffer: bytearray = field(default_factory=bytearray)

    def __hash__(self) -> int:
        """Hash by ``session_id``."""
        return hash(self.session_id)

    async def close(self) -> None:
        """Close the writer and wait for the close to complete.

        Any ``Exception`` raised while waiting is deliberately ignored.
        """
        stream = self.writer
        stream.close()
        try:
            await stream.wait_closed()
        except Exception:
            # Best-effort shutdown: a failure while waiting is not propagated.
            pass
class InvalidTransition(Exception):
    """Raised when an event has no matching transition from the current state."""
@dataclass(frozen=True)
class Transition:
    """One immutable row of the state-machine transition table."""

    # States from which this row may fire.
    from_states: frozenset[ServerState]
    # Event that triggers this row.
    event: StateEvent
    # Declared destination state for this row.
    to_state: Optional[ServerState]
    # Optional name of an extra guard; ``None`` means the row is unconditional.
    condition: Optional[str] = None
# Transition table, scanned in order; the first row whose source state, event
# and (optional) condition all match is the one that applies.
TRANSITIONS: list[Transition] = [
    # A writer may attach when nothing is held, or on top of a committed state.
    Transition(
        frozenset({ServerState.EMPTY, ServerState.COMMITTED}),
        StateEvent.RW_CONNECT,
        ServerState.RW,
    ),
    # Writer commits.
    Transition(
        frozenset({ServerState.RW}),
        StateEvent.RW_COMMIT,
        ServerState.COMMITTED,
    ),
    # Writer aborts.
    Transition(
        frozenset({ServerState.RW}),
        StateEvent.RW_ABORT,
        ServerState.EMPTY,
    ),
    # A reader attaches (first reader or an additional one).
    Transition(
        frozenset({ServerState.COMMITTED, ServerState.RO}),
        StateEvent.RO_CONNECT,
        ServerState.RO,
    ),
    # A reader detaches while other readers remain.
    Transition(
        frozenset({ServerState.RO}),
        StateEvent.RO_DISCONNECT,
        ServerState.RO,
        condition="has_remaining_readers",
    ),
    # The last reader detaches.
    Transition(
        frozenset({ServerState.RO}),
        StateEvent.RO_DISCONNECT,
        ServerState.COMMITTED,
        condition="is_last_reader",
    ),
]
class GMSFSM:
    """Connection/lock state machine for one server.

    Holds at most one writer connection, a set of reader connections, and a
    flag recording whether a commit has happened since a writer last attached.
    The reported ``state`` is derived from those three fields.
    """

    def __init__(self):
        self._rw_conn: Optional[Connection] = None
        self._ro_conns: Set[Connection] = set()
        self._committed = False

    @property
    def state(self) -> ServerState:
        """Current state; a writer outranks readers, which outrank the commit flag."""
        if self._rw_conn is not None:
            current = ServerState.RW
        elif self._ro_conns:
            current = ServerState.RO
        elif self._committed:
            current = ServerState.COMMITTED
        else:
            current = ServerState.EMPTY
        return current

    @property
    def rw_conn(self) -> Optional[Connection]:
        """The attached writer connection, or ``None``."""
        return self._rw_conn

    @property
    def ro_conns(self) -> Set[Connection]:
        """The live set of reader connections (the internal set, not a copy)."""
        return self._ro_conns

    @property
    def ro_count(self) -> int:
        """Number of attached reader connections."""
        return len(self._ro_conns)

    @property
    def committed(self) -> bool:
        """Whether the commit flag is currently set."""
        return self._committed

    def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
        """Evaluate a named transition guard for ``conn``.

        Raises:
            ValueError: If ``condition`` is not ``None`` and not a known guard.
        """
        if condition is None:
            return True
        readers = self._ro_conns
        if condition == "has_remaining_readers":
            # True when removing ``conn`` would still leave a reader behind,
            # or when ``conn`` is not a tracked reader at all.
            return len(readers) > 1 or conn not in readers
        if condition == "is_last_reader":
            return len(readers) == 1 and conn in readers
        raise ValueError(f"Unknown condition: {condition}")

    def transition(self, event: StateEvent, conn: Connection) -> ServerState:
        """Apply ``event`` for ``conn`` and return the resulting state.

        Raises:
            InvalidTransition: If no row of ``TRANSITIONS`` matches the current
                state, the event, and the row's condition.
        """
        from_state = self.state
        matched = next(
            (
                row
                for row in TRANSITIONS
                if from_state in row.from_states
                and row.event == event
                and self._check_condition(row.condition, conn)
            ),
            None,
        )
        if matched is None:
            raise InvalidTransition(
                f"No transition for {event.name} from state {from_state.name} "
                f"(session={conn.session_id})"
            )

        if event == StateEvent.RW_CONNECT:
            # A new writer clears the commit flag.
            self._rw_conn = conn
            self._committed = False
        elif event == StateEvent.RW_COMMIT:
            self._committed = True
            self._rw_conn = None
        elif event == StateEvent.RW_ABORT:
            self._rw_conn = None
        elif event == StateEvent.RO_CONNECT:
            self._ro_conns.add(conn)
        elif event == StateEvent.RO_DISCONNECT:
            self._ro_conns.discard(conn)
        # The result is re-derived from the fields, not read from the table row.
        return self.state

    def can_acquire_rw(self) -> bool:
        """True when there is neither a writer nor any reader attached."""
        if self._rw_conn is not None:
            return False
        return not self._ro_conns

    def can_acquire_ro(self, waiting_writers: int) -> bool:
        """True when committed, no writer is attached, and no writers are waiting."""
        if not self._committed or self._rw_conn is not None:
            return False
        return waiting_writers == 0
......@@ -11,6 +11,7 @@ from collections import deque
from dataclasses import dataclass
from typing import Callable, Optional
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
......@@ -42,15 +43,10 @@ from gpu_memory_service.common.protocol.messages import (
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from .allocations import AllocationInfo, GMSAllocationManager
from .session import Connection, GMSSessionManager
from .fsm import Connection, ServerState, StateEvent
from .session import GMSSessionManager
logger = logging.getLogger(__name__)
......
......@@ -23,8 +23,9 @@ from gpu_memory_service.common.protocol.wire import recv_message, send_message
from gpu_memory_service.common.utils import fail
from .allocations import AllocationNotFoundError
from .fsm import Connection, InvalidTransition
from .gms import GMS
from .session import Connection, InvalidTransition, OperationNotAllowed
from .session import OperationNotAllowed
logger = logging.getLogger(__name__)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Server-side connection, FSM, and waiter state."""
"""Server-side lock acquisition and cleanup."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Optional, Set
from gpu_memory_service.common.types import (
RO_ALLOWED,
RW_ALLOWED,
RW_REQUIRED,
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
from dataclasses import dataclass
from typing import Optional
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
CommitRequest,
ExportAllocationRequest,
FreeAllocationRequest,
GetAllocationRequest,
GetAllocationStateRequest,
GetLockStateRequest,
GetStateHashRequest,
ListAllocationsRequest,
MetadataDeleteRequest,
MetadataGetRequest,
MetadataListRequest,
MetadataPutRequest,
)
@dataclass(eq=False)
class Connection:
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
mode: GrantedLockType
session_id: str
recv_buffer: bytearray = field(default_factory=bytearray)
def __hash__(self) -> int:
return hash(self.session_id)
async def close(self) -> None:
self.writer.close()
try:
await self.writer.wait_closed()
except Exception:
pass
class InvalidTransition(Exception):
"""Raised when an invalid state transition is attempted."""
from .fsm import GMSFSM, Connection, ServerState, StateEvent
class OperationNotAllowed(Exception):
"""Raised when an operation is not allowed in the current state/mode."""
@dataclass(frozen=True)
class Transition:
from_states: frozenset[ServerState]
event: StateEvent
to_state: Optional[ServerState]
condition: Optional[str] = None
TRANSITIONS: list[Transition] = [
Transition(
from_states=frozenset({ServerState.EMPTY, ServerState.COMMITTED}),
event=StateEvent.RW_CONNECT,
to_state=ServerState.RW,
),
Transition(
from_states=frozenset({ServerState.RW}),
event=StateEvent.RW_COMMIT,
to_state=ServerState.COMMITTED,
),
Transition(
from_states=frozenset({ServerState.RW}),
event=StateEvent.RW_ABORT,
to_state=ServerState.EMPTY,
),
Transition(
from_states=frozenset({ServerState.COMMITTED, ServerState.RO}),
event=StateEvent.RO_CONNECT,
to_state=ServerState.RO,
),
Transition(
from_states=frozenset({ServerState.RO}),
event=StateEvent.RO_DISCONNECT,
to_state=ServerState.RO,
condition="has_remaining_readers",
),
Transition(
from_states=frozenset({ServerState.RO}),
event=StateEvent.RO_DISCONNECT,
to_state=ServerState.COMMITTED,
condition="is_last_reader",
),
]
class GMSLocalFSM:
"""Explicit connection/lock state machine."""
def __init__(self):
self._rw_conn: Optional[Connection] = None
self._ro_conns: Set[Connection] = set()
self._committed = False
@property
def state(self) -> ServerState:
if self._rw_conn is not None:
return ServerState.RW
if self._ro_conns:
return ServerState.RO
if self._committed:
return ServerState.COMMITTED
return ServerState.EMPTY
@property
def rw_conn(self) -> Optional[Connection]:
return self._rw_conn
@property
def ro_conns(self) -> Set[Connection]:
return self._ro_conns
@property
def ro_count(self) -> int:
return len(self._ro_conns)
@property
def committed(self) -> bool:
return self._committed
def _has_remaining_readers(self, conn: Connection) -> bool:
return len(self._ro_conns) > 1 or conn not in self._ro_conns
def _is_last_reader(self, conn: Connection) -> bool:
return len(self._ro_conns) == 1 and conn in self._ro_conns
def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
if condition is None:
return True
if condition == "has_remaining_readers":
return self._has_remaining_readers(conn)
if condition == "is_last_reader":
return self._is_last_reader(conn)
raise ValueError(f"Unknown condition: {condition}")
def _find_transition(
self,
from_state: ServerState,
event: StateEvent,
conn: Connection,
) -> Optional[Transition]:
for transition in TRANSITIONS:
if from_state not in transition.from_states:
continue
if transition.event != event:
continue
if not self._check_condition(transition.condition, conn):
continue
return transition
return None
pass
def _apply_event(self, event: StateEvent, conn: Connection) -> None:
if event == StateEvent.RW_CONNECT:
self._rw_conn = conn
self._committed = False
elif event == StateEvent.RW_COMMIT:
self._committed = True
self._rw_conn = None
elif event == StateEvent.RW_ABORT:
self._rw_conn = None
elif event == StateEvent.RO_CONNECT:
self._ro_conns.add(conn)
elif event == StateEvent.RO_DISCONNECT:
self._ro_conns.discard(conn)
def transition(self, event: StateEvent, conn: Connection) -> ServerState:
transition = self._find_transition(self.state, event, conn)
if transition is None:
raise InvalidTransition(
f"No transition for {event.name} from state {self.state.name} "
f"(session={conn.session_id})"
)
self._apply_event(event, conn)
return self.state
def check_operation(self, msg_type: type, conn: Connection) -> None:
if conn.mode == GrantedLockType.RW and msg_type not in RW_ALLOWED:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed for RW session in state {self.state.name}"
)
if conn.mode == GrantedLockType.RO and msg_type not in RO_ALLOWED:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed for RO session in state {self.state.name}"
)
if msg_type in RW_REQUIRED and conn.mode != GrantedLockType.RW:
raise OperationNotAllowed(
f"{msg_type.__name__} requires RW session, got {conn.mode.value}"
)
RW_REQUIRED: frozenset[type] = frozenset(
{
AllocateRequest,
FreeAllocationRequest,
MetadataPutRequest,
MetadataDeleteRequest,
CommitRequest,
}
)
def can_acquire_rw(self) -> bool:
return self._rw_conn is None and not self._ro_conns
RO_ALLOWED: frozenset[type] = frozenset(
{
ExportAllocationRequest,
GetAllocationRequest,
ListAllocationsRequest,
MetadataGetRequest,
MetadataListRequest,
GetLockStateRequest,
GetAllocationStateRequest,
GetStateHashRequest,
}
)
def can_acquire_ro(self, waiting_writers: int) -> bool:
return self._committed and self._rw_conn is None and waiting_writers == 0
RW_ALLOWED: frozenset[type] = RW_REQUIRED | RO_ALLOWED
@dataclass(frozen=True)
......@@ -215,7 +73,7 @@ class GMSSessionManager:
"""Owns lock transitions, waiter coordination, and cleanup."""
def __init__(self):
self._locking = GMSLocalFSM()
self._locking = GMSFSM()
self._waiting_writers = 0
self._reserved_rw_session_id: Optional[str] = None
self._condition = asyncio.Condition()
......@@ -336,7 +194,18 @@ class GMSSessionManager:
self._locking.transition(StateEvent.RW_COMMIT, conn)
def check_operation(self, msg_type: type, conn: Connection) -> None:
self._locking.check_operation(msg_type, conn)
if conn.mode == GrantedLockType.RW and msg_type not in RW_ALLOWED:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed for RW session in state {self.state.name}"
)
if conn.mode == GrantedLockType.RO and msg_type not in RO_ALLOWED:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed for RO session in state {self.state.name}"
)
if msg_type in RW_REQUIRED and conn.mode != GrantedLockType.RW:
raise OperationNotAllowed(
f"{msg_type.__name__} requires RW session, got {conn.mode.value}"
)
def begin_cleanup(self, conn: Optional[Connection]) -> StateEvent | None:
if conn is None:
......
......@@ -246,6 +246,7 @@ markers = [
"stress: marks tests as stress tests",
"performance: marks tests as performance tests",
"benchmark: marks tests as benchmark tests",
"none: marks tests that do not require a framework-specific runtime",
"vllm: marks tests as requiring vllm",
"trtllm: marks tests as requiring trtllm",
"sglang: marks tests as requiring sglang",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
pytest.importorskip("gpu_memory_service", reason="gpu_memory_service is required")
......@@ -3,8 +3,9 @@
"""Tests for the flock-based failover lock.
No GPU required — these are pure Python/OS tests exercising flock
semantics across asyncio tasks and child processes.
These are pure Python/OS tests exercising flock semantics across asyncio
tasks and child processes, so they stay on the generic cpu-style pre-merge
lane instead of the dedicated GPU job.
"""
import asyncio
......@@ -19,6 +20,7 @@ from gpu_memory_service.failover_lock.flock import FlockFailoverLock
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
......
......@@ -9,12 +9,13 @@ from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
LocalMapping,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.none,
pytest.mark.gpu_1,
]
......
......@@ -6,15 +6,16 @@ from __future__ import annotations
import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
CommitResponse,
HandshakeResponse,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
......
......@@ -15,6 +15,7 @@ from gpu_memory_service.common.protocol.messages import (
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
......
......@@ -10,7 +10,12 @@ import pytest
from tests.gms.harness.gms import GMSServerProcess
from tests.utils.managed_process import ManagedProcess
pytestmark = [pytest.mark.pre_merge, pytest.mark.unit, pytest.mark.gpu_0]
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_1,
]
@pytest.fixture
......
......@@ -24,19 +24,16 @@ from gpu_memory_service.client.memory_manager import (
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest,
GetEventHistoryResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
)
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
)
from gpu_memory_service.server import allocations as server_allocations
from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.fsm import ServerState
from gpu_memory_service.server.rpc import GMSRPCServer
from tests.gms.harness.gms import ServerThread
......@@ -44,7 +41,8 @@ from tests.gms.harness.gms import ServerThread
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.none,
pytest.mark.gpu_1,
]
......
......@@ -15,6 +15,7 @@ from dataclasses import dataclass
import pytest
from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
CommitRequest,
CommitResponse,
......@@ -24,13 +25,8 @@ from gpu_memory_service.common.protocol.messages import (
GetRuntimeStateRequest,
HandshakeRequest,
)
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.fsm import ServerState, StateEvent
from gpu_memory_service.server.gms import GMS
from gpu_memory_service.server.rpc import GMSRPCServer, _is_connection_alive
from gpu_memory_service.server.session import (
......@@ -46,7 +42,8 @@ from cuda.bindings import driver as cuda # noqa: E402
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.none,
pytest.mark.gpu_1,
]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment