Unverified Commit a6d970e9 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat: refactor GMS client memory manager with tiered API (#6549)

parent 651ef5b5
......@@ -34,7 +34,7 @@ This leads to:
│ │ │ Memory Manager │ │ ◄── Unix ───────►│ │ GMSRPCClient │ │ │
│ │ └────────────────┘ │ Socket │ └─────────────────────────────────┘ │ │
│ │ │ + │ │ │
│ │ ┌────────────────┐ │ FD │ Writer-only: allocate_and_map, commit │ │
│ │ ┌────────────────┐ │ FD │ Writer-only: create_mapping, commit │ │
│ │ │ State Machine │ │ (SCM_RIGHTS) └─────────────────────────────────────────┘ │
│ │ └────────────────┘ │ │
│ │ │ ┌─────────────────────────────────────────┐ │
......@@ -44,8 +44,8 @@ This leads to:
│ │ │ Socket │ │ GMSRPCClient │ │ │
│ └────────────────────┘ + │ └─────────────────────────────────┘ │ │
│ FD │ │ │
│ (SCM_RIGHTS) │ Reader-only: import_allocation, │ │
│ │ unmap, remap │ │
│ (SCM_RIGHTS) │ Reader-only: create_mapping (import), │ │
│ │ unmap_all_vas, remap │ │
│ └─────────────────────────────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────────────────┘
......@@ -189,17 +189,18 @@ sequenceDiagram
participant C as GMSClientMemoryManager
participant S as GMS Server
W->>C: new GMSClientMemoryManager(mode=RW)
W->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
W->>C: mgr.connect(RW)
C->>S: HandshakeRequest(lock_type=RW)
S-->>C: HandshakeResponse(success=true)
loop For each tensor
W->>C: allocate_and_map(size, tag)
W->>C: mgr.create_mapping(size=size, tag=tag)
Note over C,S: See Memory Allocation Flow above
W->>C: metadata_put(key, allocation_id, offset, shape)
W->>C: mgr.metadata_put(key, allocation_id, offset, shape)
end
W->>C: commit()
W->>C: mgr.commit()
C->>S: CommitRequest()
S->>S: FSM: RW → COMMITTED
S-->>C: CommitResponse(success=true)
......@@ -215,17 +216,18 @@ sequenceDiagram
participant C as GMSClientMemoryManager
participant S as GMS Server
R->>C: new GMSClientMemoryManager(mode=RO)
R->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
R->>C: mgr.connect(RO)
C->>S: HandshakeRequest(lock_type=RO)
S-->>C: HandshakeResponse(success=true, committed=true)
R->>C: metadata_list()
R->>C: mgr.metadata_list()
S-->>C: keys=[...]
loop For each tensor key
R->>C: metadata_get(key)
R->>C: mgr.metadata_get(key)
S-->>C: allocation_id, offset, shape
R->>C: import_allocation(allocation_id)
R->>C: mgr.create_mapping(allocation_id=allocation_id)
Note over C,S: See Memory Import Flow above
end
......@@ -245,7 +247,7 @@ sequenceDiagram
Note over R,GPU: Need to temporarily release GPU memory
R->>C: unmap()
R->>C: mgr.unmap_all_vas()
C->>GPU: cudaDeviceSynchronize()
loop For each mapping
......@@ -254,18 +256,19 @@ sequenceDiagram
Note over C: Keep VA reservation!
end
C->>C: Save memory_layout_hash
R->>C: mgr.disconnect()
C->>S: Close socket (release RO lock)
S->>S: FSM: RO → COMMITTED (if last reader)
Note over R,GPU: GPU memory released, VA preserved
Note over R,GPU: Another writer could modify weights here
R->>C: remap()
R->>C: mgr.connect(RO)
C->>S: HandshakeRequest(lock_type=RO)
S->>S: FSM: COMMITTED → RO
S-->>C: HandshakeResponse(success=true)
R->>C: mgr.remap_all_vas()
C->>S: GetStateHashRequest()
S-->>C: GetStateHashResponse(hash)
......@@ -295,7 +298,8 @@ sequenceDiagram
Note over P,S: Auto-mode: Writer if first, Reader if weights exist
P->>C: new GMSClientMemoryManager(mode=RW_OR_RO)
P->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
P->>C: mgr.connect(RW_OR_RO)
C->>S: HandshakeRequest(lock_type=RW_OR_RO)
alt No committed weights AND no RW holder
......@@ -340,11 +344,11 @@ Benefits:
### 3. VA-Stable Unmap/Remap
During `unmap()`:
During `unmap_all_vas()`:
- Physical memory is released (`cuMemUnmap` + `cuMemRelease`)
- VA reservations are **kept** (`cuMemAddressReserve` still valid)
During `remap()`:
During `remap_all_vas()`:
- Same VAs are reused for mapping
- **Tensor pointers remain valid** (no need to update PyTorch tensors)
......@@ -354,7 +358,7 @@ On commit, the server computes a hash of:
- All allocation IDs, sizes, and tags
- All metadata entries
On `remap()`, this hash is checked:
On `remap_all_vas()`, this hash is checked:
- If match: Safe to remap (layout unchanged)
- If mismatch: Raise `StaleMemoryLayoutError` (must re-import)
......@@ -393,46 +397,121 @@ fd = fds[0] if fds else -1
### GMSClientMemoryManager
The API is organized in two tiers. **Tier 2 (convenience)** is what integrations normally use. **Tier 1 (atomic)** exposes individual operations for advanced callers.
```python
class GMSClientMemoryManager:
def __init__(
socket_path: str,
mode: RequestedLockType, # RW, RO, or RW_OR_RO
device: int = 0,
timeout_ms: Optional[int] = None,
): ...
def __init__(socket_path: str, *, device: int = 0): ...
# Properties
@property mode: GrantedLockType # Actual granted mode
@property granted_lock_type: Optional[GrantedLockType]
@property is_connected: bool
@property is_unmapped: bool
@property total_bytes: int
# Allocation (RW only)
def allocate_and_map(size: int, tag: str = "default") -> int # Returns VA
def free_mapping(va: int) -> None
def clear_all() -> int # Returns count cleared
# Import (RO or RW)
def import_allocation(allocation_id: str) -> int # Returns VA
# Metadata (RW: put/delete, RO: get/list)
# --- Tier 1: Connection ---
def connect(lock_type: RequestedLockType, timeout_ms: Optional[int] = None) -> None
def disconnect() -> None
# --- Tier 1: Handle ops (server-side, RW only) ---
def allocate_handle(size: int, tag: str = "default") -> str # Returns allocation_id
def export_handle(allocation_id: str) -> int # Returns FD
def get_handle_info(allocation_id: str) -> AllocationInfo
def free_handle(allocation_id: str) -> bool
def clear_all_handles() -> int # Returns count cleared
def commit() -> bool # Transition to COMMITTED
def get_memory_layout_hash() -> str
def list_handles(tag: Optional[str] = None) -> List[Dict]
# --- Tier 1: VA ops (local) ---
def reserve_va(size: int) -> int # Returns VA
def map_va(fd, va, size, allocation_id, tag) -> int # Returns handle
def unmap_va(va: int) -> None # Keeps VA reservation
def free_va(va: int) -> None # Releases VA reservation
# --- Tier 1: Metadata ---
def metadata_put(key: str, allocation_id: str, offset: int, value: bytes) -> bool
def metadata_get(key: str) -> Optional[Tuple[str, int, bytes]]
def metadata_list(prefix: str = "") -> List[str]
def metadata_delete(key: str) -> bool
# Lifecycle
def commit() -> bool # Publish weights, release RW lock
def switch_to_read(timeout_ms: Optional[int] = None) -> None
def unmap() -> None # Release RO lock, preserve VAs
def remap(timeout_ms: Optional[int] = None) -> bool
def close() -> None
# --- Tier 2: Convenience ---
def create_mapping(allocation_id=None, size=0, tag="default") -> int # Allocate or import
def destroy_mapping(va: int) -> None
def unmap_all_vas() -> None # Sync + unmap all, preserve VA reservations
def remap_all_vas() -> None # Re-import at preserved VAs (checks layout hash)
def reallocate_all_handles(tag="default") -> None # Fresh server handles for preserved VAs
def close(free: bool = False) -> None
```
## Limitations
1. **Single-GPU per server**: Each GMS server manages one GPU device
2. **CUDA VMM required**: Requires a GPU with Virtual Memory Management support. Check at runtime via `CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED` - there is no guaranteed minimum compute capability
3. **No content validation**: Remap doesn't detect in-place weight modifications
---
## Framework Integration (vLLM / SGLang)
GMS provides pre-built integrations for vLLM and SGLang. Enable GMS by passing `--load-format gms` when launching an engine.
### How It Works
When `--load-format gms` is set:
1. **A GMS server must already be running** for the target GPU device. The engine connects to it via a Unix socket derived from the GPU UUID.
2. The engine uses `RW_OR_RO` mode by default: the **first** process gets RW (loads weights from disk, commits to GMS), and **subsequent** processes get RO (import weights from GMS metadata).
3. Weights are managed by GMS; KV cache is managed by the framework's own allocator (e.g., vLLM's `CuMemAllocator`).
#### vLLM
```bash
python -m dynamo.vllm \
--model <model> \
--load-format gms \
--enable-sleep-mode \
--gpu-memory-utilization 0.9
```
The integration uses a custom worker class (`GMSWorker`) that:
- Establishes the GMS connection early in `init_device()` so vLLM's `MemorySnapshot` can account for committed weights
- Registers a custom model loader (`GMSModelLoader`) for the `gms` load format
- Patches `torch.cuda.empty_cache` to avoid releasing GMS-managed memory
- Routes weight allocation through a `CUDAPluggableAllocator` backed by GMS
#### SGLang
```bash
python -m dynamo.sglang \
--model-path <model> \
--load-format gms \
--enable-memory-saver \
--mem-fraction-static 0.9
```
The integration patches `torch_memory_saver` to route weight operations through GMS:
- Weights (`"weights"` / `"model_weights"` tags) go through `GMSMemorySaverImpl`
- Other tags (e.g., `"kv_cache"`) are delegated to the default torch mempool implementation
- The `--enable-memory-saver` flag is required to activate the memory saver pathway
### Shadow Engine Failover (Sleep / Wake)
Both integrations support releasing and reclaiming GPU memory for shadow engine patterns. The API names differ by framework:
- **vLLM**: `sleep` / `wake_up` (via `/engine/sleep` and `/engine/wake_up` HTTP endpoints)
- **SGLang**: `release_memory_occupation` / `resume_memory_occupation` (via the corresponding HTTP endpoints)
Under the hood, sleeping calls `unmap_all_vas()` + `disconnect()` to release GPU memory while preserving VA reservations, and waking calls `connect(RO)` + `remap_all_vas()` to re-import weights at the same virtual addresses. Tensor pointers remain valid, so no model re-initialization is needed.
This enables a shadow engine to release its GPU memory, let a primary engine use the GPU, and then reclaim the memory after the primary is killed.
### Configuration via `model_loader_extra_config`
To force read-only mode (import only, never load from disk), pass `gms_read_only` via the framework's `--model-loader-extra-config` flag:
```bash
--model-loader-extra-config '{"gms_read_only": true}'
```
This forces `RO` lock mode instead of the default `RW_OR_RO` auto-detection. The engine will only import existing committed weights and fail if none are available.
......@@ -9,6 +9,8 @@ for importing, mapping, and unmapping GPU memory.
from __future__ import annotations
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import check_cuda_result
from gpu_memory_service.common.types import GrantedLockType
......@@ -17,18 +19,25 @@ from gpu_memory_service.common.types import GrantedLockType
def import_handle_from_fd(fd: int) -> int:
"""Import a CUDA memory handle from a file descriptor.
Closes the FD after import — the imported handle holds its own reference
to the physical allocation. Leaving the FD open leaks a DMA-buf ref that
prevents cuMemRelease from freeing GPU memory.
Args:
fd: POSIX file descriptor received via SCM_RIGHTS.
Returns:
CUDA memory handle.
"""
result, handle = cuda.cuMemImportFromShareableHandle(
fd,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
)
check_cuda_result(result, "cuMemImportFromShareableHandle")
return int(handle)
try:
result, handle = cuda.cuMemImportFromShareableHandle(
fd,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
)
check_cuda_result(result, "cuMemImportFromShareableHandle")
return int(handle)
finally:
os.close(fd)
def reserve_va(size: int, granularity: int) -> int:
......@@ -111,6 +120,31 @@ def release_handle(handle: int) -> None:
check_cuda_result(result, "cuMemRelease")
def validate_pointer(va: int) -> bool:
"""Validate that a mapped VA is accessible.
Returns True if the pointer is valid, False otherwise (logs a warning).
"""
result, _dev_ptr = cuda.cuPointerGetAttribute(
cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va
)
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
err_msg = ""
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
import logging
logging.getLogger(__name__).warning(
"cuPointerGetAttribute failed for VA 0x%x: %s (%s)",
va,
result,
err_msg,
)
return False
return True
def synchronize() -> None:
"""Synchronize the current CUDA context.
......
......@@ -3,14 +3,23 @@
"""GPU Memory Service client-side memory manager.
This is the unified memory manager for the GPU Memory Service architecture.
Two-tier API for GPU memory lifecycle management:
Key properties:
- Uses GMSRPCClient over a Unix-domain socket.
- The socket connection itself is the RW/RO lock.
- In write mode, the manager can allocate + map RW and then publish via commit().
- In read mode, the manager can import + map RO and hold the RO lock during inference.
- unmap()/remap() releases and reacquires the RO lock (and remaps allocations).
Tier 1 (Atomic Operations):
- Connection: connect(), disconnect()
- Handle ops (server-side cuMem allocations): allocate_handle, export_handle,
get_handle_info, free_handle, clear_all_handles, commit, list_handles,
get_memory_layout_hash
- VA ops (local address space): reserve_va, map_va, unmap_va, free_va
- Metadata: metadata_put, metadata_get, metadata_list, metadata_delete
Tier 2 (Convenience — compose Tier 1 with error handling + sync):
- create_mapping, destroy_mapping
- unmap_all_vas, remap_all_vas, reallocate_all_handles
- close
Integrations (vLLM/SGLang) call Tier 2. Advanced callers (e.g., KV failover)
can compose Tier 1 atomics directly.
This module uses cuda-python bindings for CUDA driver API calls:
- import FDs (cuMemImportFromShareableHandle)
......@@ -25,17 +34,19 @@ import logging
from dataclasses import dataclass
from typing import Dict, List, Optional
from cuda.bindings import driver as cuda
from gpu_memory_service.client.cuda_vmm_utils import free_va as _cuda_free_va
from gpu_memory_service.client.cuda_vmm_utils import (
free_va,
import_handle_from_fd,
map_to_va,
release_handle,
reserve_va,
)
from gpu_memory_service.client.cuda_vmm_utils import reserve_va as _cuda_reserve_va
from gpu_memory_service.client.cuda_vmm_utils import (
set_access,
set_current_device,
synchronize,
unmap,
validate_pointer,
)
from gpu_memory_service.client.rpc import GMSRPCClient
from gpu_memory_service.common.cuda_vmm_utils import (
......@@ -56,11 +67,11 @@ class StaleMemoryLayoutError(Exception):
IMPORTANT: This is a LAYOUT check, NOT a CONTENT check.
- Detected: Allocation sizes changed, tensors added/removed, metadata structure changed
- NOT detected: Weight values modified in-place
- NOT detected: Data values modified in-place
This design is intentional: unmap/remap enables use cases like RL training
where another process can write to the same memory locations (e.g., updating
weights) while preserving the structure. As long as the layout (allocation
data) while preserving the structure. As long as the layout (allocation
and metadata table hashes) remains identical, remap() succeeds.
"""
......@@ -69,7 +80,16 @@ class StaleMemoryLayoutError(Exception):
@dataclass(frozen=True)
class LocalMapping:
"""Immutable record of a local VA mapping."""
"""Immutable record of a local VA mapping.
Fields:
- allocation_id: Server-side allocation ID
- va: Local virtual address
- size: Original requested size
- aligned_size: Size aligned to VMM granularity
- handle: CUDA memory handle (0 if unmapped but VA reserved)
- tag: Allocation tag for server tracking
"""
allocation_id: str
va: int
......@@ -77,7 +97,6 @@ class LocalMapping:
aligned_size: int
handle: int # 0 if unmapped but VA reserved
tag: str
access: GrantedLockType
def with_handle(self, handle: int) -> "LocalMapping":
return LocalMapping(
......@@ -87,94 +106,53 @@ class LocalMapping:
self.aligned_size,
handle,
self.tag,
self.access,
)
def with_access(self, access: GrantedLockType) -> "LocalMapping":
def with_allocation_id(self, allocation_id: str) -> "LocalMapping":
return LocalMapping(
self.allocation_id,
allocation_id,
self.va,
self.size,
self.aligned_size,
self.handle,
self.tag,
access,
)
class GMSClientMemoryManager:
"""Unified memory manager that can act as writer or reader.
"""Unified memory manager for GPU Memory Service.
Modes:
- mode=RequestedLockType.RW: acquire RW lock, allocate/map RW, mutate metadata, commit/publish.
- mode=RequestedLockType.RO: acquire RO lock (READY only), import/map RO, unmap/remap.
- mode=RequestedLockType.RW_OR_RO: try RW if available, else wait for RO.
Constructor does NOT connect — call connect() explicitly after construction.
"""
def __init__(
self,
socket_path: str,
*,
mode: RequestedLockType,
device: int = 0,
timeout_ms: Optional[int] = None,
) -> None:
self.socket_path = socket_path
self.device = device
self._timeout_ms = timeout_ms
self._client: Optional[GMSRPCClient] = None
self._mappings: Dict[int, LocalMapping] = {} # va -> mapping
self._allocation_id_to_va: Dict[str, int] = {}
self._inverse_mapping: Dict[str, int] = {}
self._unmapped = False
self._closed = False
self._preserved_allocation_ids: List[str] = []
self._published = False
self._mode: Optional[GrantedLockType] = None # Updated by _connect
self._granted_lock_type: Optional[GrantedLockType] = None
# VA-stable unmap/remap state
self._va_preserved = False
self._last_memory_layout_hash: str = (
"" # Hash from server, saved on connect/commit
)
self._last_memory_layout_hash: str = ""
# Set the current CUDA device for subsequent operations.
set_current_device(self.device)
# Cache granularity for VA alignment
self.granularity = get_allocation_granularity(device)
self._connect(lock_type=mode, timeout_ms=timeout_ms)
def _connect(
self,
*,
lock_type: RequestedLockType,
timeout_ms: Optional[int],
update_memory_layout_hash: bool = True,
) -> None:
self._client = GMSRPCClient(
self.socket_path, lock_type=lock_type, timeout_ms=timeout_ms
)
self._unmapped = False
# Update mode based on granted lock type (may differ from requested for rw_or_ro)
self._mode = self._client.lock_type
# Save state hash for stale detection on remap (skip during remap itself)
if update_memory_layout_hash and self._client.committed:
self._last_memory_layout_hash = self._client.get_memory_layout_hash()
@property
def mode(self) -> Optional[GrantedLockType]:
"""Current mode of the memory manager."""
return self._mode
# ==================== Properties ====================
@property
def lock_type(self) -> Optional[GrantedLockType]:
"""Get the lock type actually granted by the server."""
if self._client is None:
return None
return self._client.lock_type
def granted_lock_type(self) -> Optional[GrantedLockType]:
return self._granted_lock_type
@property
def is_connected(self) -> bool:
......@@ -186,15 +164,98 @@ class GMSClientMemoryManager:
@property
def mappings(self) -> Dict[int, LocalMapping]:
"""Read-only view of VA -> LocalMapping dictionary."""
return self._mappings
@property
def total_bytes(self) -> int:
"""Total bytes allocated across all mappings."""
return sum(m.aligned_size for m in self._mappings.values())
# ==================== Metadata convenience ====================
# ==================== Tier 1: Connection ====================
def connect(
self, lock_type: RequestedLockType, timeout_ms: Optional[int] = None
) -> None:
"""Connect to GMS server and acquire lock.
Updates self._granted_lock_type based on granted lock type. Saves memory layout hash
for stale detection if server is in committed state.
"""
self._client = GMSRPCClient(
self.socket_path,
lock_type=lock_type,
timeout_ms=timeout_ms,
)
self._granted_lock_type = self._client.lock_type
# Save layout hash for stale detection on future remap
if self._client.committed:
self._last_memory_layout_hash = self._client.get_memory_layout_hash()
def disconnect(self) -> None:
"""Close connection and release lock."""
if self._client is not None:
try:
self._client.close()
except Exception:
pass
self._client = None
# ==================== Tier 1: Handle Operations (server-side) ====================
def allocate_handle(self, size: int, tag: str = "default") -> str:
"""Allocate a cuMem handle on the server.
Returns allocation_id. Size is aligned to VMM granularity before sending.
"""
self._require_rw()
aligned_size = align_to_granularity(size, self.granularity)
allocation_id, server_aligned = self._client_rpc.allocate(aligned_size, tag)
if int(server_aligned) != aligned_size:
raise RuntimeError(
f"Alignment mismatch: {aligned_size} vs {server_aligned}"
)
return allocation_id
def export_handle(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD."""
return self._client_rpc.export(allocation_id)
def get_handle_info(self, allocation_id: str):
"""Query allocation info from server."""
return self._client_rpc.get_allocation(allocation_id)
def free_handle(self, allocation_id: str) -> bool:
"""Release a cuMem allocation on the server."""
return self._client_rpc.free(allocation_id)
def clear_all_handles(self) -> int:
"""Clear all allocations on the server. NO local unmap.
Safe at startup (no local mappings) and during failover
(preserves local VA reservations).
"""
self._require_rw()
return self._client_rpc.clear_all()
def commit(self) -> bool:
"""Server-only commit: transition to COMMITTED state.
No synchronize(), no CUDA access flip. The caller is responsible for
synchronizing before calling this. Server closes the RW socket on
success, so self._client becomes None.
"""
self._require_rw()
ok = self._client_rpc.commit()
if ok:
self._client = None
return bool(ok)
def get_memory_layout_hash(self) -> str:
return self._client_rpc.get_memory_layout_hash()
def list_handles(self, tag: Optional[str] = None) -> List[Dict]:
return self._client_rpc.list_allocations(tag)
# ==================== Tier 1: Metadata ====================
def metadata_put(
self, key: str, allocation_id: str, offset_bytes: int, value: bytes
......@@ -210,294 +271,323 @@ class GMSClientMemoryManager:
def metadata_delete(self, key: str) -> bool:
return self._client_rpc.metadata_delete(key)
# ==================== Allocation operations ====================
def list_allocations(self, tag: Optional[str] = None) -> List[Dict]:
"""List all allocations on the server."""
return self._client_rpc.list_allocations(tag)
# ==================== Tier 1: VA Operations (local) ====================
def allocate_and_map(self, size: int, tag: str = "default") -> int:
"""Allocate on server, reserve VA, and map locally.
def reserve_va(self, size: int) -> int:
"""Reserve virtual address space (cuMemAddressReserve). No tracking."""
aligned_size = align_to_granularity(size, self.granularity)
return _cuda_reserve_va(aligned_size, self.granularity)
Args:
size: Requested allocation size in bytes.
tag: Allocation tag for server tracking.
def map_va(self, fd: int, va: int, size: int, allocation_id: str, tag: str) -> int:
"""Import FD + cuMemMap + set access + track.
Returns:
Virtual address of the mapped allocation.
Access is set based on current lock_type. Returns the CUDA handle.
"""
self._require_rw()
client = self._client_rpc
assert self._granted_lock_type is not None
aligned_size = align_to_granularity(size, self.granularity)
va = reserve_va(aligned_size, self.granularity)
handle = import_handle_from_fd(fd)
try:
allocation_id, server_aligned = client.allocate(aligned_size, tag)
if int(server_aligned) != aligned_size:
raise RuntimeError(
f"Alignment mismatch: {aligned_size} vs {server_aligned}"
)
fd = client.export(allocation_id)
handle = import_handle_from_fd(fd)
map_to_va(va, aligned_size, handle)
set_access(va, aligned_size, self.device, GrantedLockType.RW)
self._track_mapping(
LocalMapping(
allocation_id=allocation_id,
va=va,
size=size,
aligned_size=aligned_size,
handle=handle,
tag=tag,
access=GrantedLockType.RW,
)
)
return va
set_access(va, aligned_size, self.device, self._granted_lock_type)
except Exception:
free_va(va, aligned_size)
raise
def free_mapping(self, va: int) -> None:
"""Unmap and free a local mapping."""
mapping = self._mappings.pop(va, None)
if mapping is None:
return
self._allocation_id_to_va.pop(mapping.allocation_id, None)
try:
if mapping.handle != 0:
unmap(va, mapping.aligned_size)
release_handle(mapping.handle)
free_va(va, mapping.aligned_size)
except Exception as e:
logger.warning(f"Error freeing VA 0x{va:x}: {e}")
if self.lock_type == GrantedLockType.RW and not self._published:
try:
self._client_rpc.free(mapping.allocation_id)
unmap(va, aligned_size)
except Exception:
pass
def import_allocation(self, allocation_id: str) -> int:
"""Import an existing allocation and map locally.
In RO mode, maps read-only. In RW mode, maps read-write.
"""
if allocation_id in self._allocation_id_to_va:
return self._allocation_id_to_va[allocation_id]
client = self._client_rpc
# lock_type is guaranteed non-None when connected (after _client_rpc succeeds)
assert self.lock_type is not None
current_access = self.lock_type
alloc_info = client.get_allocation(allocation_id)
aligned_size = int(alloc_info.aligned_size)
size = int(alloc_info.size)
tag = str(getattr(alloc_info, "tag", "default"))
va = reserve_va(aligned_size, self.granularity)
try:
fd = client.export(allocation_id)
handle = import_handle_from_fd(fd)
map_to_va(va, aligned_size, handle)
set_access(va, aligned_size, self.device, current_access)
self._track_mapping(
LocalMapping(
allocation_id=allocation_id,
va=va,
size=size,
aligned_size=aligned_size,
handle=handle,
tag=tag,
access=current_access,
)
)
return va
except Exception:
free_va(va, aligned_size)
release_handle(handle)
raise
self._track_mapping(
LocalMapping(
allocation_id=allocation_id,
va=va,
size=size,
aligned_size=aligned_size,
handle=handle,
tag=tag,
)
)
return handle
def clear_all(self) -> int:
"""Clear all allocations on the server (RW only). Local mappings are unmapped first."""
self._require_rw()
self._unmap_all()
return self._client_rpc.clear_all()
# ==================== Publish / mode switching ====================
def commit(self) -> bool:
"""Publish weights (RW only).
Client responsibilities:
- cudaDeviceSynchronize() before publishing
- flip local mappings to RO before publishing
def unmap_va(self, va: int) -> None:
"""Unmap a single VA: cuMemUnmap + release handle.
Server responsibilities:
- transition to COMMITTED
- close the RW socket (publish + release)
Keeps the VA reservation and tracking entry (handle set to 0).
Works in both RW and RO modes.
"""
self._require_rw()
synchronize()
# After publishing, prevent further writes locally.
for va, m in list(self._mappings.items()):
if m.access != GrantedLockType.RO:
set_access(m.va, m.aligned_size, self.device, GrantedLockType.RO)
self._mappings[va] = m.with_access(GrantedLockType.RO)
ok = self._client_rpc.commit()
self._published = bool(ok)
# _client.commit() closes the socket on success; reflect that here.
if ok:
self._client = None
return bool(ok)
mapping = self._mappings.get(va)
if mapping is None or mapping.handle == 0:
return
unmap(va, mapping.aligned_size)
release_handle(mapping.handle)
self._mappings[va] = mapping.with_handle(0)
def switch_to_read(self, timeout_ms: Optional[int] = None) -> None:
"""Acquire an RO lock after publishing.
def free_va(self, va: int) -> None:
"""Release a VA reservation: cuMemAddressFree + untrack.
This is intended for the common flow where a writer loads weights and then
becomes a reader for inference.
Unmaps first if still mapped.
"""
if self._closed:
raise RuntimeError("Memory manager is closed")
if self._unmapped:
raise RuntimeError(
"Cannot switch_to_read() while unmapped; call remap() first"
)
if self._client is not None:
if self.lock_type == GrantedLockType.RO:
mapping = self._mappings.get(va)
if mapping is None:
return
if mapping.handle != 0:
self.unmap_va(va)
mapping = self._mappings.get(va)
if mapping is None:
return
raise RuntimeError(
"switch_to_read() requires the RW connection to be released (call commit() first)"
)
_cuda_free_va(va, mapping.aligned_size)
self._mappings.pop(va, None)
self._inverse_mapping.pop(mapping.allocation_id, None)
eff_timeout = timeout_ms if timeout_ms is not None else self._timeout_ms
self._connect(lock_type=RequestedLockType.RO, timeout_ms=eff_timeout)
# ==================== Tier 2: Convenience ====================
# ==================== Unmap / remap (read mode) ====================
def create_mapping(
self,
allocation_id: Optional[str] = None,
size: int = 0,
tag: str = "default",
) -> int:
"""Allocate or import a handle and map to a new VA.
def unmap(self) -> None:
"""Release RO lock and unmap local allocations (VA-stable).
If allocation_id is None (allocate path):
allocate_handle -> export_handle -> reserve_va -> map_va
VAs are preserved during unmap so tensor pointers remain stable.
On remap, allocations are remapped to the same VAs.
If allocation_id given (import path, cached):
Check cache -> get_handle_info -> export_handle -> reserve_va -> map_va
"""
if self._closed:
raise RuntimeError("Memory manager is closed")
if self._unmapped:
if allocation_id is not None:
# Import path: check cache first
cached_va = self._inverse_mapping.get(allocation_id)
if cached_va is not None:
mapping = self._mappings.get(cached_va)
if mapping is not None and mapping.handle == 0:
raise RuntimeError(
f"Allocation {allocation_id} is cached but unmapped "
f"(VA 0x{cached_va:x}). Use remap_all_vas() to restore."
)
return cached_va
info = self.get_handle_info(allocation_id)
alloc_size = int(info.size)
aligned_size = int(info.aligned_size)
alloc_tag = str(getattr(info, "tag", "default"))
fd = self.export_handle(allocation_id)
va = self.reserve_va(aligned_size)
try:
self.map_va(fd, va, alloc_size, allocation_id, alloc_tag)
except Exception:
_cuda_free_va(va, align_to_granularity(aligned_size, self.granularity))
raise
return va
# Allocate path
if size <= 0:
raise ValueError("size must be > 0 when allocation_id is None")
alloc_id = self.allocate_handle(size, tag)
fd = self.export_handle(alloc_id)
aligned_size = align_to_granularity(size, self.granularity)
va = self.reserve_va(aligned_size)
try:
self.map_va(fd, va, size, alloc_id, tag)
except Exception:
_cuda_free_va(va, aligned_size)
raise
return va
def destroy_mapping(self, va: int) -> None:
"""Unmap + free VA + free server handle for a single mapping."""
mapping = self._mappings.get(va)
if mapping is None:
return
if self.lock_type != GrantedLockType.RO:
raise RuntimeError("unmap() requires RO mode")
synchronize()
alloc_id = mapping.allocation_id
# Preserve allocation IDs for remapping on remap
self._preserved_allocation_ids = list(self._allocation_id_to_va.keys())
try:
self.unmap_va(va)
except Exception as e:
logger.warning("Error in unmap_va for 0x%x: %s", va, e)
# Unmap physical memory but keep VA reservations
self._unmap_preserving_va()
self._va_preserved = True
try:
self.free_va(va)
except Exception as e:
logger.warning("Error in free_va for 0x%x: %s", va, e)
self._client_rpc.close()
self._client = None
self._unmapped = True
# Only free server handle if we're RW and haven't committed
if self._granted_lock_type == GrantedLockType.RW:
try:
self.free_handle(alloc_id)
except Exception:
pass
def remap(self, timeout_ms: Optional[int] = None) -> bool:
"""Reacquire RO lock and remap preserved allocations (VA-stable).
def unmap_all_vas(self) -> None:
"""Synchronize + unmap all VAs. Preserves VA reservations for remap."""
synchronize()
Allocations are remapped to the same VAs they had before unmap,
ensuring tensor pointers remain valid.
unmapped_count = 0
total_bytes = 0
for va, mapping in list(self._mappings.items()):
if mapping.handle == 0:
continue
try:
self.unmap_va(va)
unmapped_count += 1
total_bytes += mapping.aligned_size
except Exception as e:
logger.warning("Error unmapping VA 0x%x: %s", va, e)
Args:
timeout_ms: Timeout for RO lock acquisition.
self._va_preserved = True
self._unmapped = True
logger.info(
"[GPU Memory Service] Unmapped %d allocations (%.2f GiB), "
"preserving %d VA reservations",
unmapped_count,
total_bytes / (1 << 30),
len(self._mappings),
)
Returns:
True on success.
def remap_all_vas(self) -> None:
"""Re-import existing handles at preserved VAs.
Raises:
TimeoutError: If timeout_ms expires waiting for RO lock.
StaleMemoryLayoutError: If weights were structurally changed while unmapped.
Checks layout hash for staleness. Validates each allocation still
exists and size matches before remapping.
"""
if self._closed:
raise RuntimeError("Memory manager is closed")
if not self._unmapped:
return True
set_current_device(self.device)
eff_timeout = timeout_ms if timeout_ms is not None else self._timeout_ms
self._connect(
lock_type=RequestedLockType.RO,
timeout_ms=eff_timeout,
update_memory_layout_hash=False,
)
# Check if memory layout changed while unmapped
current_hash = self._client_rpc.get_memory_layout_hash()
# Stale layout check
current_hash = self.get_memory_layout_hash()
if (
self._last_memory_layout_hash
and current_hash != self._last_memory_layout_hash
):
raise StaleMemoryLayoutError(
f"State changed while unmapped: hash {self._last_memory_layout_hash[:16]}... -> {current_hash[:16]}..."
f"Layout changed: {self._last_memory_layout_hash[:16]}... -> {current_hash[:16]}..."
)
# Remap to preserved VAs
assert self._granted_lock_type is not None
remapped_count = 0
failed_count = 0
total_bytes = 0
for alloc_id in self._preserved_allocation_ids:
for va, mapping in list(self._mappings.items()):
if mapping.handle != 0:
continue # Already mapped
# Validate allocation still exists
try:
va = self._remap_preserved_va(alloc_id)
mapping = self._mappings.get(va)
if mapping:
total_bytes += mapping.aligned_size
remapped_count += 1
except StaleMemoryLayoutError:
raise # Let StaleMemoryLayoutError propagate
alloc_info = self.get_handle_info(mapping.allocation_id)
except Exception as e:
logger.warning(f"Failed to remap {alloc_id}: {e}")
failed_count += 1
raise StaleMemoryLayoutError(
f"Allocation {mapping.allocation_id} no longer exists: {e}"
) from e
if int(alloc_info.aligned_size) != mapping.aligned_size:
raise StaleMemoryLayoutError(
f"Allocation {mapping.allocation_id} size changed: "
f"{mapping.aligned_size} vs {int(alloc_info.aligned_size)}"
)
# Re-import and map to preserved VA
fd = self.export_handle(mapping.allocation_id)
handle = import_handle_from_fd(fd)
map_to_va(va, mapping.aligned_size, handle)
set_access(va, mapping.aligned_size, self.device, self._granted_lock_type)
synchronize()
validate_pointer(va)
self._mappings[va] = mapping.with_handle(handle)
remapped_count += 1
total_bytes += mapping.aligned_size
self._va_preserved = False
self._unmapped = False
logger.info(
"[GPU Memory Service] Remap complete on device %d: "
"remapped %d allocations (%.2f GiB)",
self.device,
remapped_count,
total_bytes / (1 << 30),
)
if failed_count > 0:
def reallocate_all_handles(self, tag: str = "default") -> None:
"""Allocate fresh server handles for all preserved VAs (no mapping).
Used during failover: the shadow engine's VAs are still reserved,
but the physical memory was freed. This allocates new server-side
handles and updates tracking (handle stays 0 — call remap_all_vas()
afterward to actually map them).
"""
self._require_rw()
if not self._va_preserved:
raise RuntimeError(
f"Remap failed: {failed_count} of {len(self._preserved_allocation_ids)} "
f"allocations could not be remapped"
"reallocate_all_handles requires preserved VAs (call unmap_all_vas first)"
)
reallocated = 0
for va, mapping in list(self._mappings.items()):
if mapping.handle != 0:
continue
# Allocate fresh handle on server (uses raw RPC to avoid re-aligning)
allocation_id, server_aligned = self._client_rpc.allocate(
mapping.aligned_size, tag
)
if int(server_aligned) != mapping.aligned_size:
raise RuntimeError(
f"Alignment mismatch during reallocation: "
f"{mapping.aligned_size} vs {server_aligned}"
)
# Update tracking: new allocation_id, handle stays 0
old_alloc_id = mapping.allocation_id
self._inverse_mapping.pop(old_alloc_id, None)
self._mappings[va] = mapping.with_allocation_id(allocation_id)
self._inverse_mapping[allocation_id] = va
reallocated += 1
logger.info(
f"[GPU Memory Service] Remap complete on device {self.device}: "
f"remapped {remapped_count} allocations ({total_bytes / (1 << 30):.2f} GiB)"
"[GPU Memory Service] Reallocated %d handles for preserved VAs",
reallocated,
)
self._unmapped = False
self._va_preserved = False
return True
# ==================== Lifecycle ====================
# ==================== Cleanup ====================
def close(self, free: bool = False) -> None:
"""Best-effort cleanup. NOT reliable in crash/signal paths.
def close(self) -> None:
if self._closed:
return
synchronize + unmap all + free all VAs + disconnect.
free=True: also clear_all_handles() on server before disconnect.
VAs are freed by CUDA context teardown on process exit anyway.
"""
try:
synchronize()
except Exception:
pass
# Ensure kernels are done before tearing down mappings.
synchronize()
for va in list(self._mappings.keys()):
try:
self.unmap_va(va)
except Exception as e:
logger.warning("Error unmapping VA 0x%x during close: %s", va, e)
# Release all mappings including preserved VA reservations
self._unmap_all()
for va in list(self._mappings.keys()):
try:
self.free_va(va)
except Exception as e:
logger.warning("Error freeing VA 0x%x during close: %s", va, e)
if self._client is not None:
self._client.close()
self._client = None
self._closed = True
if (
free
and self._client is not None
and self._granted_lock_type == GrantedLockType.RW
):
try:
self.clear_all_handles()
except Exception as e:
logger.warning("Error clearing handles during close: %s", e)
self.disconnect()
self._unmapped = False
self._va_preserved = False
self._preserved_allocation_ids.clear()
def __enter__(self) -> "GMSClientMemoryManager":
return self
......@@ -509,7 +599,7 @@ class GMSClientMemoryManager:
@property
def _client_rpc(self) -> GMSRPCClient:
"""Get connected client or raise. Use instead of _require_connected() + assert."""
"""Get connected client or raise."""
if self._client is None:
if self._unmapped:
raise RuntimeError("Memory manager is unmapped")
......@@ -517,129 +607,9 @@ class GMSClientMemoryManager:
return self._client
def _require_rw(self) -> None:
"""Raise if not in RW mode."""
if self.lock_type != GrantedLockType.RW:
if self._granted_lock_type != GrantedLockType.RW:
raise RuntimeError("Operation requires RW mode")
def _track_mapping(self, m: LocalMapping) -> None:
self._mappings[m.va] = m
self._allocation_id_to_va[m.allocation_id] = m.va
def _unmap_preserving_va(self) -> None:
"""Unmap physical memory but PRESERVE VA reservations for unmap/remap.
This keeps the VA reservation intact so tensors maintain stable pointers.
On remap, we can remap to the same VAs.
"""
unmapped_count = 0
total_bytes = 0
for va, mapping in list(self._mappings.items()):
if mapping.handle == 0:
continue # Already unmapped
try:
unmap(va, mapping.aligned_size)
release_handle(mapping.handle)
self._mappings[va] = mapping.with_handle(
0
) # Mark unmapped, VA reserved
unmapped_count += 1
total_bytes += mapping.aligned_size
except Exception as e:
logger.warning(
f"Error unmapping VA 0x{va:x} (preserving reservation): {e}"
)
logger.info(
f"[GPU Memory Service] Unmapped {unmapped_count} allocations ({total_bytes / (1 << 30):.2f} GiB), "
f"preserving {len(self._mappings)} VA reservations"
)
def _remap_preserved_va(self, allocation_id: str) -> int:
"""Remap an allocation to its preserved VA.
Requires the VA to already be reserved (from before unmap).
Validates allocation still exists and size matches.
Returns the VA.
Raises StaleMemoryLayoutError if allocation is missing or size changed.
"""
set_current_device(self.device)
va = self._allocation_id_to_va.get(allocation_id)
if va is None:
raise RuntimeError(f"No preserved VA for allocation {allocation_id}")
mapping = self._mappings.get(va)
if mapping is None:
raise RuntimeError(f"No mapping info for VA 0x{va:x}")
if mapping.handle != 0:
return va # Already mapped
client = self._client_rpc
# lock_type is guaranteed non-None when connected (after _client_rpc succeeds)
assert self.lock_type is not None
current_access = self.lock_type
# Validate allocation still exists and size matches
try:
alloc_info = client.get_allocation(allocation_id)
except Exception as e:
raise StaleMemoryLayoutError(
f"Allocation {allocation_id} no longer exists on server: {e}"
) from e
server_aligned_size = int(alloc_info.aligned_size)
if server_aligned_size != mapping.aligned_size:
raise StaleMemoryLayoutError(
f"Allocation {allocation_id} size changed: expected {mapping.aligned_size}, got {server_aligned_size}"
)
# Re-import the handle and map to the SAME VA (which is still reserved)
fd = client.export(allocation_id)
handle = import_handle_from_fd(fd)
map_to_va(va, mapping.aligned_size, handle)
# Set access permissions based on current lock type
set_access(va, mapping.aligned_size, self.device, current_access)
# Synchronize to ensure mapping is complete before any access
synchronize()
# Validate the pointer is accessible (this is what Triton checks)
result, _dev_ptr = cuda.cuPointerGetAttribute(
cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va
)
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
err_msg = ""
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = (
err_str.decode() if isinstance(err_str, bytes) else str(err_str)
)
logger.warning(
f"[GPU Memory Service] cuPointerGetAttribute failed for VA 0x{va:x} after remap: "
f"error {result} ({err_msg})"
)
else:
logger.debug(
f"[GPU Memory Service] Remapped VA 0x{va:x} validated OK (device={self.device})"
)
# Update mapping with new handle and access
updated = mapping.with_handle(handle)
self._mappings[va] = updated.with_access(current_access)
return va
def _unmap_all(self) -> None:
"""Unmap and release all local mappings including VA reservations."""
for va, mapping in list(self._mappings.items()):
try:
if mapping.handle != 0:
unmap(va, mapping.aligned_size)
release_handle(mapping.handle)
free_va(va, mapping.aligned_size)
except Exception as e:
logger.warning(f"Error unmapping VA 0x{va:x}: {e}")
self._mappings.clear()
self._allocation_id_to_va.clear()
self._inverse_mapping[m.allocation_id] = m.va
......@@ -129,20 +129,30 @@ class GMSRPCClient:
try:
self._socket.connect(self.socket_path)
except FileNotFoundError:
self._socket.close()
self._socket = None
raise ConnectionError(f"Server not running at {self.socket_path}") from None
except Exception as e:
self._socket.close()
self._socket = None
raise ConnectionError(f"Failed to connect: {e}") from e
# Send handshake (this IS lock acquisition)
request = HandshakeRequest(
lock_type=self._requested_lock_type, timeout_ms=timeout_ms
)
send_message_sync(self._socket, request)
# Handshake I/O — clean up socket on any failure
try:
request = HandshakeRequest(
lock_type=self._requested_lock_type,
timeout_ms=timeout_ms,
)
send_message_sync(self._socket, request)
# Receive response (may block waiting for lock)
response, _, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
)
# May block waiting for lock
response, _, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
)
except Exception:
self._socket.close()
self._socket = None
raise
if isinstance(response, ErrorResponse):
self._socket.close()
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service allocator singleton management.
"""GPU Memory Service allocator management (singleton).
Manages the singleton memory manager and PyTorch MemPool integration.
Manages a single weights memory manager and PyTorch MemPool integration.
Only one GMS scope is needed: weights. KV cache is handled by CuMemAllocator.
"""
from __future__ import annotations
......@@ -19,12 +20,53 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Global singleton state
_gms_client_memory_manager: Optional["GMSClientMemoryManager"] = None
# Singleton state
_manager: Optional["GMSClientMemoryManager"] = None
_mem_pool: Optional["MemPool"] = None
_tag: str = "weights"
_callbacks_initialized: bool = False
_pluggable_alloc: Optional[Any] = None
def _gms_malloc(size: int, device: int, stream: int) -> int:
"""Route malloc to the singleton weights manager."""
if _manager is None:
raise RuntimeError("No GMS manager initialized")
va = _manager.create_mapping(size=int(size), tag=_tag)
logger.debug("[GMS] malloc: va=0x%x size=%d", va, size)
return va
def _gms_free(ptr: int, size: int, device: int, stream: int) -> None:
"""Route free to the singleton weights manager."""
if _manager is None:
logger.warning("[GMS] free: no manager, ignoring va=0x%x", ptr)
return
if int(ptr) in _manager.mappings:
logger.debug("[GMS] free: va=0x%x size=%d", ptr, size)
_manager.destroy_mapping(int(ptr))
else:
logger.warning("[GMS] free: manager does not own va=0x%x, ignoring", ptr)
def _ensure_callbacks_initialized() -> "MemPool":
"""Initialize C-level callbacks exactly once, return a new MemPool."""
global _callbacks_initialized, _pluggable_alloc
from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem
from torch.cuda import CUDAPluggableAllocator
from torch.cuda.memory import MemPool
if not _callbacks_initialized:
_pluggable_alloc = CUDAPluggableAllocator(
cumem.__file__, "my_malloc", "my_free"
)
cumem.init_module(_gms_malloc, _gms_free)
_callbacks_initialized = True
return MemPool(allocator=_pluggable_alloc.allocator())
def get_or_create_gms_client_memory_manager(
socket_path: str,
device: int,
......@@ -33,7 +75,7 @@ def get_or_create_gms_client_memory_manager(
tag: str = "weights",
timeout_ms: Optional[int] = None,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Get existing memory manager or create a new one.
"""Get existing memory manager, or create a new one.
Args:
socket_path: Unix socket path for the allocation server.
......@@ -45,80 +87,53 @@ def get_or_create_gms_client_memory_manager(
Returns:
(gms_client_memory_manager, pool) - pool is None for RO mode.
"""
global _gms_client_memory_manager, _mem_pool
global _manager, _mem_pool, _tag
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
if _gms_client_memory_manager is not None:
if _manager is not None:
return _get_existing(mode)
# Create new manager
gms_client_memory_manager = GMSClientMemoryManager(
socket_path, mode=mode, device=device, timeout_ms=timeout_ms
)
_gms_client_memory_manager = gms_client_memory_manager
manager = GMSClientMemoryManager(socket_path, device=device)
manager.connect(mode, timeout_ms=timeout_ms)
if gms_client_memory_manager.mode == GrantedLockType.RW:
_mem_pool = _setup_mempool(gms_client_memory_manager, tag)
if manager.granted_lock_type == GrantedLockType.RW:
pool = _ensure_callbacks_initialized()
# Only set globals after mempool succeeds (avoids partial singleton)
_manager = manager
_tag = tag
_mem_pool = pool
logger.info("[GMS] Created RW allocator (device=%d)", device)
return gms_client_memory_manager, _mem_pool
return manager, pool
else:
_manager = manager
_tag = tag
logger.info("[GMS] Created RO allocator (device=%d)", device)
return gms_client_memory_manager, None
return manager, None
def _get_existing(
mode: RequestedLockType,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Return existing allocator if mode-compatible."""
current = _gms_client_memory_manager.mode
assert _manager is not None
current = _manager.granted_lock_type
if mode == RequestedLockType.RW:
if current == GrantedLockType.RW:
return _gms_client_memory_manager, _mem_pool
return _manager, _mem_pool
raise RuntimeError(f"Cannot get RW allocator: existing is in {current} mode")
if mode == RequestedLockType.RO:
if current == GrantedLockType.RO:
return _gms_client_memory_manager, None
raise RuntimeError(
f"Cannot get RO allocator: existing is in {current} mode. "
"Call manager.switch_to_read() first."
)
return _manager, None
raise RuntimeError(f"Cannot get RO allocator: existing is in {current} mode")
# RW_OR_RO: return whatever exists
pool = _mem_pool if current == GrantedLockType.RW else None
return _gms_client_memory_manager, pool
def _setup_mempool(
gms_client_memory_manager: "GMSClientMemoryManager",
tag: str,
) -> "MemPool":
"""Set up PyTorch CUDAPluggableAllocator and MemPool."""
global _pluggable_alloc
from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem
from torch.cuda import CUDAPluggableAllocator
from torch.cuda.memory import MemPool
pluggable_alloc = CUDAPluggableAllocator(cumem.__file__, "my_malloc", "my_free")
pool = MemPool(allocator=pluggable_alloc.allocator())
_pluggable_alloc = pluggable_alloc
def malloc_cb(size: int, device: int, stream: int) -> int:
va = gms_client_memory_manager.allocate_and_map(int(size), tag=tag)
logger.debug("[GMS] malloc: va=0x%x size=%d", va, size)
return va
def free_cb(ptr: int, size: int, device: int, stream: int) -> None:
logger.debug("[GMS] free: va=0x%x size=%d", ptr, size)
gms_client_memory_manager.free_mapping(int(ptr))
cumem.init_module(malloc_cb, free_cb)
return pool
effective_pool = _mem_pool if current == GrantedLockType.RW else None
return _manager, effective_pool
def get_gms_client_memory_manager() -> Optional["GMSClientMemoryManager"]:
"""Get the active GMS client memory manager, or None if not initialized."""
return _gms_client_memory_manager
"""Get the active GMS client memory manager, or None."""
return _manager
......@@ -214,7 +214,9 @@ class GMSTensorSpec:
device_index: int,
) -> torch.Tensor:
"""Create a tensor aliasing mapped CUDA memory."""
base_va = gms_client_memory_manager.import_allocation(self.allocation_id)
base_va = gms_client_memory_manager.create_mapping(
allocation_id=self.allocation_id
)
ptr = int(base_va) + int(self.offset_bytes)
return _tensor_from_pointer(
......
......@@ -16,6 +16,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def get_gms_lock_mode(extra_config: dict):
"""Resolve GMS lock mode from model_loader_extra_config.
Returns RO if gms_read_only=True, otherwise RW_OR_RO (default).
"""
from gpu_memory_service.common.types import RequestedLockType
if extra_config.get("gms_read_only", False):
logger.info("[GMS] gms_read_only=True, forcing RO mode")
return RequestedLockType.RO
return RequestedLockType.RW_OR_RO
def setup_meta_tensor_workaround() -> None:
"""Enable workaround for meta tensor operations like torch.nonzero()."""
try:
......@@ -30,9 +43,8 @@ def finalize_gms_write(
allocator: "GMSClientMemoryManager", model: torch.nn.Module
) -> int:
"""Finalize GMS write mode: register tensors, commit, switch to read.
This is typically called when the (writing) model loader finishes, and
is ready to commit the weights so that other engines can import these
weights and read them.
Flow: register tensors -> sync -> commit (server-only) -> disconnect -> connect(RO)
Args:
allocator: The GMS client memory manager in write mode.
......@@ -45,17 +57,20 @@ def finalize_gms_write(
RuntimeError: If commit fails.
"""
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.types import RequestedLockType
register_module_tensors(allocator, model)
total_bytes = allocator.total_bytes
# Wait for all writes to weights (from caller) to complete before mode switch
# Synchronize before commit — caller's writes must be visible
torch.cuda.synchronize()
if not allocator.commit():
raise RuntimeError("GMS commit failed")
allocator.switch_to_read()
# commit() closed the RW socket; acquire RO for inference
allocator.disconnect() # no-op if commit already cleared _client, but safe
allocator.connect(RequestedLockType.RO)
logger.info(
"[GMS] Committed %.2f GiB, switched to read mode with %d mappings",
......
......@@ -20,6 +20,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Module-level GMS lock mode, set by setup_gms() before loader is instantiated.
# Read by patches.py when creating GMSMemorySaverImpl.
_gms_lock_mode = None
def setup_gms(server_args) -> Type["GMSModelLoader"]:
"""Setup GPU Memory Service for SGLang.
......@@ -46,6 +50,19 @@ def setup_gms(server_args) -> Type["GMSModelLoader"]:
"Cannot use --enable-draft-weights-cpu-backup with --load-format gms."
)
# Resolve lock mode from model_loader_extra_config before patches fire
global _gms_lock_mode
extra = getattr(server_args, "model_loader_extra_config", None)
if isinstance(extra, str):
import json
extra = json.loads(extra) if extra else {}
extra = extra or {}
from gpu_memory_service.integrations.common.utils import get_gms_lock_mode
_gms_lock_mode = get_gms_lock_mode(extra)
# Import triggers patches at module level
from gpu_memory_service.integrations.sglang.model_loader import GMSModelLoader
......
......@@ -51,14 +51,15 @@ class GMSMemorySaverImpl:
torch_impl: "_TorchMemorySaverImpl",
socket_path: str,
device_index: int,
mode=None,
):
self._torch_impl = torch_impl
self._socket_path = socket_path
self._device_index = device_index
self._requested_mode = mode
self._disabled = False
self._imported_weights_bytes: int = 0
# Initialize allocator with auto mode
self._allocator: Optional["GMSClientMemoryManager"]
self._mem_pool: Optional["MemPool"]
self._mode: str
......@@ -74,19 +75,20 @@ class GMSMemorySaverImpl:
def _init_allocator(
self,
) -> tuple[Optional["GMSClientMemoryManager"], Optional["MemPool"], str]:
"""Create allocator with automatic mode selection."""
"""Create allocator with mode from config (default: RW_OR_RO)."""
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
mode = self._requested_mode or RequestedLockType.RW_OR_RO
allocator, mem_pool = get_or_create_gms_client_memory_manager(
self._socket_path,
self._device_index,
mode=RequestedLockType.RW_OR_RO,
mode=mode,
tag="weights",
)
granted_mode = allocator.mode
granted_mode = allocator.granted_lock_type
if granted_mode == GrantedLockType.RW:
allocator.clear_all()
allocator.clear_all_handles()
actual_mode = "write"
else:
actual_mode = "read"
......@@ -151,7 +153,8 @@ class GMSMemorySaverImpl:
if self._allocator.is_unmapped:
return
logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap()
self._allocator.unmap_all_vas()
self._allocator.disconnect()
def _resume_weights(self) -> None:
if self._allocator is None:
......@@ -159,7 +162,10 @@ class GMSMemorySaverImpl:
if not self._allocator.is_unmapped:
return
logger.info("[GMS] Remapping weights (VA-stable)")
self._allocator.remap()
from gpu_memory_service.common.types import RequestedLockType
self._allocator.connect(RequestedLockType.RO)
self._allocator.remap_all_vas()
def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read."""
......
......@@ -67,11 +67,14 @@ def patch_torch_memory_saver() -> None:
# Create underlying torch impl for non-weights tags (KV cache etc.)
torch_impl = _TorchMemorySaverImpl(hook_mode="torch")
# Create GPU Memory Service impl
# Read lock mode set by setup_gms() (defaults to RW_OR_RO)
from gpu_memory_service.integrations.sglang import _gms_lock_mode
gms_impl = GMSMemorySaverImpl(
torch_impl=torch_impl,
socket_path=socket_path,
device_index=device_index,
mode=_gms_lock_mode,
)
# Set _impl directly (accessible via gms_impl property)
......
......@@ -17,10 +17,11 @@ from typing import TYPE_CHECKING
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common.utils import (
finalize_gms_write,
get_gms_lock_mode,
setup_meta_tensor_workaround,
)
......@@ -60,16 +61,18 @@ def register_gms_loader(load_format: str = "gms") -> None:
def load_weights(self, model: torch.nn.Module, model_config) -> None:
self.default_loader.load_weights(model, model_config)
def load_model(self, vllm_config, model_config) -> torch.nn.Module:
def load_model(self, vllm_config, model_config, prefix="") -> torch.nn.Module:
device = torch.cuda.current_device()
extra = getattr(self.load_config, "model_loader_extra_config", {}) or {}
mode = get_gms_lock_mode(extra)
gms_client, pool = get_or_create_gms_client_memory_manager(
get_socket_path(device),
device,
mode=RequestedLockType.RW_OR_RO,
mode=mode,
tag="weights",
)
if gms_client.mode == GrantedLockType.RO:
if gms_client.granted_lock_type == GrantedLockType.RO:
return _load_read_mode(gms_client, vllm_config, model_config, device)
else:
return _load_write_mode(
......@@ -133,7 +136,7 @@ def _load_write_mode(
)
from vllm.utils.torch_utils import set_default_torch_dtype
gms_client.clear_all()
gms_client.clear_all_handles()
# Allocate model tensors using GMS memory pool
with set_default_torch_dtype(model_config.dtype):
......
......@@ -43,8 +43,8 @@ def patch_memory_snapshot() -> None:
manager = get_gms_client_memory_manager()
assert manager is not None, "GMS client is not initialized"
if manager.mode == GrantedLockType.RO:
allocations = manager.list_allocations()
if manager.granted_lock_type == GrantedLockType.RO:
allocations = manager.list_handles()
committed_bytes = sum(alloc.get("aligned_size", 0) for alloc in allocations)
else:
# NOTE: by design, we want to assume we have the whole GPU when writing
......
......@@ -24,6 +24,7 @@ from gpu_memory_service import (
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import get_gms_lock_mode
from gpu_memory_service.integrations.vllm.model_loader import register_gms_loader
from gpu_memory_service.integrations.vllm.patches import patch_memory_snapshot
......@@ -57,10 +58,17 @@ class GMSWorker(Worker):
device = self.local_rank
current_platform.set_device(torch.device(f"cuda:{device}"))
# Establish GMS connection (so MemorySnapshot can query committed bytes)
# Establish weights GMS connection (so MemorySnapshot can query committed bytes).
# Fetch extra config from vLLM load_config to determine RW/RO lock mode.
extra = (
getattr(self.vllm_config.load_config, "model_loader_extra_config", {}) or {}
)
socket_path = get_socket_path(device)
get_or_create_gms_client_memory_manager(
socket_path, device, mode=RequestedLockType.RW_OR_RO, tag="weights"
socket_path,
device,
mode=get_gms_lock_mode(extra),
tag="weights",
)
# Parent will set device again (harmless) and do memory checks
......@@ -105,17 +113,18 @@ class GMSWorker(Worker):
NOTE: We do NOT call super().sleep() because it tries to copy GPU buffers to CPU,
which segfaults on already-unmapped GMS memory.
"""
from vllm.device_allocator.cumem import CuMemAllocator
free_bytes_before = torch.cuda.mem_get_info()[0]
# Unmap GMS weights (VA-stable unmap, no CPU backup needed)
# Unmap GMS weights: synchronize + unmap all VAs + disconnect
manager = get_gms_client_memory_manager()
assert manager is not None, "GMS client is not initialized"
assert not manager.is_unmapped, "GMS weights are already unmapped"
manager.unmap()
manager.unmap_all_vas()
manager.disconnect()
# Sleep KV cache via CuMemAllocator
from vllm.device_allocator.cumem import CuMemAllocator
# Sleep KV cache via CuMemAllocator (discard, no CPU backup)
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=tuple())
......@@ -130,8 +139,6 @@ class GMSWorker(Worker):
def wake_up(self, tags: Optional[List[str]] = None) -> None:
"""vLLM wake implementation with GMS integration."""
from vllm.device_allocator.cumem import CuMemAllocator
if tags is None:
tags = ["weights", "kv_cache"]
......@@ -139,9 +146,12 @@ class GMSWorker(Worker):
manager = get_gms_client_memory_manager()
assert manager is not None, "GMS client is not initialized"
assert manager.is_unmapped, "GMS weights are not unmapped"
manager.remap()
manager.connect(RequestedLockType.RO)
manager.remap_all_vas()
if "kv_cache" in tags:
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=["kv_cache"])
......@@ -154,8 +164,8 @@ class GMSWorker(Worker):
def _maybe_get_memory_pool_context(self, tag: str):
"""Skip CuMemAllocator for weights when using GMS.
GMS manages its own memory pool for weights, so we don't want
vLLM's CuMemAllocator to interfere.
GMS manages its own memory pool for weights, so we don't want vLLM's
CuMemAllocator to interfere.
"""
if tag == "weights":
logger.debug("[GMS] Skipping CuMemAllocator for weights")
......
......@@ -10,7 +10,7 @@ from gpu_memory_service.common.types import (
StateSnapshot,
)
from gpu_memory_service.server.handler import MetadataEntry, RequestHandler
from gpu_memory_service.server.locking import Connection, GlobalLockFSM
from gpu_memory_service.server.locking import Connection, GMSLocalFSM
from gpu_memory_service.server.memory_manager import (
AllocationInfo,
AllocationNotFoundError,
......@@ -29,6 +29,6 @@ __all__ = [
"RequestedLockType",
"RequestHandler",
"ServerState",
"GlobalLockFSM",
"GMSLocalFSM",
"StateSnapshot",
]
......@@ -5,7 +5,7 @@
This module handles:
- Connection: Represents an active client connection
- GlobalLockFSM: Explicit state transitions with validated permissions
- GMSLocalFSM: Explicit state transitions with validated permissions
State Diagram:
......@@ -174,7 +174,7 @@ class TransitionRecord:
session_id: Optional[str] = None
class GlobalLockFSM:
class GMSLocalFSM:
"""Explicit state machine for GPU Memory Service.
State is DERIVED from actual connection objects:
......@@ -330,7 +330,12 @@ class GlobalLockFSM:
)
# Record transition
record = TransitionRecord(from_state, event, to_state, session_id)
record = TransitionRecord(
from_state,
event,
to_state,
session_id=session_id,
)
self._transition_log.append(record)
logger.info(
......
......@@ -68,7 +68,7 @@ class GMSServerMemoryManager:
so it doesn't create a CUDA context. This allows it to survive GPU
driver failures.
- NOT thread-safe: Callers must provide external synchronization.
The GlobalLockFSM's RW/RO semantics ensure single-writer access.
The GMSLocalFSM's RW/RO semantics ensure single-writer access.
"""
def __init__(self, device: int = 0):
......
......@@ -3,8 +3,8 @@
"""Async Allocation RPC Server - Single-threaded event loop with explicit state machine.
State transitions are explicit and validated by the GlobalLockFSM class.
Operations are checked against state/mode permissions before execution.
State transitions are explicit and validated by the GMSLocalFSM class.
Operations are checked against state/mode permissions before operation.
State Machine (see locking.py for full diagram):
EMPTY: No connections, not committed
......@@ -49,7 +49,7 @@ from gpu_memory_service.common.types import (
)
from .handler import RequestHandler
from .locking import Connection, GlobalLockFSM
from .locking import Connection, GMSLocalFSM
logger = logging.getLogger(__name__)
......@@ -57,12 +57,16 @@ logger = logging.getLogger(__name__)
class GMSRPCServer:
"""GPU Memory Service RPC Server.
Async single-threaded server using GlobalLockFSM for explicit state transitions
Async single-threaded server using GMSLocalFSM for explicit state transitions
and operation validation. All state mutations happen through the state machine's
transition() method.
"""
def __init__(self, socket_path: str, device: int = 0):
def __init__(
self,
socket_path: str,
device: int = 0,
):
self.socket_path = socket_path
self.device = device
......@@ -70,7 +74,7 @@ class GMSRPCServer:
self._handler = RequestHandler(device)
# State machine - handles all state transitions and permission checks
self._sm = GlobalLockFSM(on_rw_abort=self._handler.on_rw_abort)
self._sm = GMSLocalFSM(on_rw_abort=self._handler.on_rw_abort)
self._waiting_writers: int = 0
# Async waiting for lock acquisition
......@@ -162,7 +166,13 @@ class GMSRPCServer:
writer.close()
return None
conn = Connection(reader, writer, granted_mode, session_id, recv_buffer)
conn = Connection(
reader=reader,
writer=writer,
mode=granted_mode,
session_id=session_id,
recv_buffer=recv_buffer,
)
# State transition: connect
event = (
......@@ -183,7 +193,9 @@ class GMSRPCServer:
return conn
async def _acquire_lock(
self, mode: RequestedLockType, timeout_ms: Optional[int]
self,
mode: RequestedLockType,
timeout_ms: Optional[int],
) -> Optional[GrantedLockType]:
"""Wait until lock can be acquired (uses state machine predicates).
......@@ -368,9 +380,7 @@ class GMSRPCServer:
async def _handle_commit(self, conn: Connection) -> tuple[object, int, bool]:
"""Handle commit via state machine transition - atomic with disconnect."""
# Compute state hash before transitioning
self._handler.on_commit()
# State transition: commit
self._sm.transition(StateEvent.RW_COMMIT, conn)
await send_message(conn.writer, CommitResponse(success=True))
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GPU Memory Service Shadow Engine Failover Test for SGLang.
Tests the shadow engine failover scenario where a sleeping shadow engine can
wake up and take over when the primary engine fails.
"""
import logging
"""GPU Memory Service Shadow Engine Failover Test for SGLang."""
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from .utils.common import GMSServerProcess, get_gpu_memory_used, send_completion
from .utils.common import run_shadow_failover_test
from .utils.sglang import SGLangWithGMSProcess
logger = logging.getLogger(__name__)
@pytest.mark.sglang
@pytest.mark.e2e
......@@ -31,69 +21,23 @@ logger = logging.getLogger(__name__)
def test_gms_shadow_engine_failover(
request, runtime_services, gms_ports, predownload_models
):
"""Test shadow engine failover with GPU Memory Service.
1. Start shadow engine and put it to sleep
2. Start primary engine and serve inference
3. Kill primary engine
4. Wake shadow engine and verify it handles inference
"""
ports = gms_ports
with GMSServerProcess(request, device=0):
with DynamoFrontendProcess(request, frontend_port=ports["frontend"]):
# Start shadow engine
with SGLangWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
) as shadow:
# Verify shadow works
result = send_completion(ports["frontend"])
logger.info(f"Shadow inference result: {result}")
assert result["choices"]
logger.info("Shadow inference OK")
# Sleep shadow (release memory occupation)
mem_before = get_gpu_memory_used()
sleep_result = shadow.sleep()
assert sleep_result["status"] == "ok"
mem_after_sleep = get_gpu_memory_used()
logger.info(
f"Shadow sleep freed {(mem_before - mem_after_sleep) / (1 << 20):.0f} MB"
)
assert mem_after_sleep < mem_before
# Start primary engine
with SGLangWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_sglang"],
ports["frontend"],
):
result = send_completion(ports["frontend"], "Primary test")
logger.info(f"Primary inference result: {result}")
assert result["choices"]
logger.info("Primary inference OK")
# Primary is dead (exited context manager)
# Wake shadow (resume memory occupation)
wake_result = shadow.wake()
assert wake_result["status"] == "ok"
# Verify shadow handles failover
result = send_completion(ports["frontend"], "After failover")
logger.info(f"Failover inference result: {result}")
assert result["choices"]
logger.info("Shadow handles failover OK")
for i in range(3):
result = send_completion(ports["frontend"], f"Verify {i}")
logger.info(f"Verification {i} result: {result}")
assert result["choices"]
logger.info("All verification passed")
run_shadow_failover_test(
request,
ports,
make_shadow=lambda: SGLangWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
),
make_primary=lambda: SGLangWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_sglang"],
ports["frontend"],
),
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GPU Memory Service Shadow Engine Failover Test for vLLM.
Tests the shadow engine failover scenario where a sleeping shadow engine can
wake up and take over when the primary engine fails.
"""
import logging
"""GPU Memory Service Shadow Engine Failover Test for vLLM."""
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from .utils.common import GMSServerProcess, get_gpu_memory_used, send_completion
from .utils.common import run_shadow_failover_test
from .utils.vllm import VLLMWithGMSProcess
logger = logging.getLogger(__name__)
@pytest.mark.vllm
@pytest.mark.e2e
......@@ -31,71 +21,25 @@ logger = logging.getLogger(__name__)
def test_gms_shadow_engine_failover(
request, runtime_services, gms_ports, predownload_models
):
"""Test shadow engine failover with GPU Memory Service.
1. Start shadow engine and put it to sleep
2. Start primary engine and serve inference
3. Kill primary engine
4. Wake shadow engine and verify it handles inference
"""
ports = gms_ports
with GMSServerProcess(request, device=0):
with DynamoFrontendProcess(request, frontend_port=ports["frontend"]):
# Start shadow engine
with VLLMWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
) as shadow:
# Verify shadow works
result = send_completion(ports["frontend"])
logger.info(f"Shadow inference result: {result}")
assert result["choices"]
logger.info("Shadow inference OK")
# Sleep shadow
mem_before = get_gpu_memory_used()
sleep_result = shadow.sleep()
assert sleep_result["status"] == "ok"
mem_after_sleep = get_gpu_memory_used()
logger.info(
f"Shadow sleep freed {(mem_before - mem_after_sleep) / (1 << 20):.0f} MB"
)
assert mem_after_sleep < mem_before
# Start primary engine
with VLLMWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_kv_event"],
ports["primary_nixl"],
ports["frontend"],
):
result = send_completion(ports["frontend"], "Primary test")
logger.info(f"Primary inference result: {result}")
assert result["choices"]
logger.info("Primary inference OK")
# Primary is dead (exited context manager)
# Wake shadow
wake_result = shadow.wake()
assert wake_result["status"] == "ok"
# Verify shadow handles failover
result = send_completion(ports["frontend"], "After failover")
logger.info(f"Failover inference result: {result}")
assert result["choices"]
logger.info("Shadow handles failover OK")
for i in range(3):
result = send_completion(ports["frontend"], f"Verify {i}")
logger.info(f"Verification {i} result: {result}")
assert result["choices"]
logger.info("All verification passed")
run_shadow_failover_test(
request,
ports,
make_shadow=lambda: VLLMWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
),
make_primary=lambda: VLLMWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_kv_event"],
ports["primary_nixl"],
ports["frontend"],
),
)
......@@ -11,14 +11,16 @@ backend-agnostic and can be used by vLLM, SGLang, or other backends.
import logging
import os
import shutil
import signal
import time
from typing import Callable
import pynvml
import requests
from gpu_memory_service.common.utils import get_socket_path
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
logger = logging.getLogger(__name__)
......@@ -33,6 +35,55 @@ def get_gpu_memory_used(device: int = 0) -> int:
pynvml.nvmlShutdown()
def kill_force(
process: ManagedProcess,
timeout_s: float = 30.0,
poll_interval_s: float = 0.5,
) -> None:
"""SIGKILL a process group and wait for GPU memory reclamation.
Snapshots GPU memory before the kill, sends SIGKILL to the entire
process group, reaps the zombie, then polls pynvml until the CUDA
driver finishes asynchronous cleanup (memory drops below the
pre-kill snapshot).
"""
mem_before = get_gpu_memory_used()
pid = process.get_pid()
if pid is None:
logger.warning("kill_force: no PID available")
return
try:
pgid = os.getpgid(pid)
logger.info(f"kill_force: sending SIGKILL to process group {pgid} (pid={pid})")
os.killpg(pgid, signal.SIGKILL)
except ProcessLookupError:
logger.warning(f"kill_force: process {pid} already dead")
return
# Reap the process to avoid zombies
try:
os.waitpid(pid, 0)
except ChildProcessError:
pass
# Wait for CUDA driver to asynchronously reclaim GPU memory
start = time.time()
mem_after = mem_before
while time.time() - start < timeout_s:
mem_after = get_gpu_memory_used()
if mem_after < mem_before:
break
time.sleep(poll_interval_s)
freed_mb = (mem_before - mem_after) / (1 << 20)
logger.info(
f"kill_force: before={mem_before / (1 << 30):.2f} GiB, "
f"after={mem_after / (1 << 30):.2f} GiB, freed={freed_mb:.0f} MB"
)
def send_completion(
port: int, prompt: str = "Hello", max_retries: int = 3, retry_delay: float = 1.0
) -> dict:
......@@ -40,12 +91,6 @@ def send_completion(
Includes retry logic to handle transient failures from stale routing
(e.g., after failover when etcd still has dead instance entries).
Args:
port: The frontend HTTP port.
prompt: The prompt to send.
max_retries: Max retries for transient failures.
retry_delay: Delay between retries in seconds.
"""
last_error = None
for attempt in range(max_retries):
......@@ -76,10 +121,7 @@ def send_completion(
class GMSServerProcess(ManagedProcess):
"""
Manages GMS server lifecycle for tests. Starts server, waits for socket, cleans up on exit.
Runs only for the specified GPU device.
"""
"""Manages GMS server lifecycle for tests."""
def __init__(self, request, device: int):
self.device = device
......@@ -115,3 +157,52 @@ class GMSServerProcess(ManagedProcess):
return True
time.sleep(0.1)
return False
def run_shadow_failover_test(
request,
ports: dict,
make_shadow: Callable[[], ManagedProcess],
make_primary: Callable[[], ManagedProcess],
) -> None:
"""Shared shadow-engine failover flow for both vLLM and SGLang.
1. Start shadow -> verify inference
2. Sleep shadow -> log memory freed
3. Start primary -> verify inference
4. kill -9 primary -> wait for GPU memory reclamation
5. Wake shadow -> verify inference x 3
"""
frontend_port = ports["frontend"]
with GMSServerProcess(request, device=0):
with DynamoFrontendProcess(request, frontend_port=frontend_port):
with make_shadow() as shadow:
# Shadow inference
result = send_completion(frontend_port)
assert result["choices"], "Shadow inference failed"
logger.info(f"Shadow inference OK: {result}")
# Sleep shadow
mem_before = get_gpu_memory_used()
assert shadow.sleep()["status"] == "ok"
mem_after = get_gpu_memory_used()
logger.info(
f"Shadow sleep: {mem_before / (1 << 30):.2f} -> "
f"{mem_after / (1 << 30):.2f} GiB "
f"(freed {(mem_before - mem_after) / (1 << 20):.0f} MB)"
)
# Primary: start, verify, kill -9
with make_primary() as primary:
result = send_completion(frontend_port, "Primary test")
assert result["choices"], "Primary inference failed"
logger.info(f"Primary inference OK: {result}")
kill_force(primary)
# Wake shadow, verify 3x
assert shadow.wake()["status"] == "ok"
for i in range(3):
result = send_completion(frontend_port, f"Verify {i}")
assert result["choices"], f"Verification {i} failed"
logger.info("All verification passed")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment