Unverified Commit 91a8c6af authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix: simplify GMS layout state and harden GPU-backed flows (#7006)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
Co-authored-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: default avatarhhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent dd7ceb4a
......@@ -28,20 +28,20 @@ This leads to:
┌──────────────────────────────────────────────────────────────────────────────────────┐
│ │
│ ┌────────────────────┐ ┌─────────────────────────────────────────┐ │
│ │ GMS Server │ │ GMSClientMemoryManager (Writer) │ │
│ │ GMS │ │ GMSClientMemoryManager (Writer) │ │
│ │ │ │ │ │
│ │ ┌────────────────┐ │ │ ┌─────────────────────────────────┐ │ │
│ │ │ Memory Manager │ │ ◄── Unix ───────►│ │ GMSRPCClient │ │ │
│ │ │ Memory Manager │ │ ◄── Unix ───────►│ │ GMS Session │ │ │
│ │ └────────────────┘ │ Socket │ └─────────────────────────────────┘ │ │
│ │ │ + │ │ │
│ │ ┌────────────────┐ │ FD │ Writer-only: create_mapping, commit │ │
│ │ │ State Machine │ │ (SCM_RIGHTS) └─────────────────────────────────────────┘ │
│ │ │ Session / FSM │ │ (SCM_RIGHTS) └─────────────────────────────────────────┘ │
│ │ └────────────────┘ │ │
│ │ │ ┌─────────────────────────────────────────┐ │
│ │ ┌────────────────┐ │ │ GMSClientMemoryManager (Reader) │ │
│ │ │ Metadata Store │ │ │ │ │
│ │ └────────────────┘ │ ◄── Unix ───────►│ ┌─────────────────────────────────┐ │ │
│ │ │ Socket │ │ GMSRPCClient │ │ │
│ │ │ Socket │ │ GMS Session │ │ │
│ └────────────────────┘ + │ └─────────────────────────────────┘ │ │
│ FD │ │ │
│ (SCM_RIGHTS) │ Reader-only: create_mapping (import), │ │
......@@ -65,25 +65,25 @@ The GMS server runs as an independent process that manages GPU memory without ev
The server consists of three main components:
1. **Memory Manager** - Allocates physical GPU memory via CUDA VMM (`cuMemCreate`) and exports shareable file descriptors (`cuMemExportToShareableHandle`). Critically, it never calls `cuMemMap` - clients handle all virtual address mapping.
1. **Memory Manager** - Allocates physical GPU memory via CUDA VMM (`cuMemCreate`) and exports shareable file descriptors (`cuMemExportToShareableHandle`). Critically, it never calls `cuMemMap` - clients handle all virtual address mapping. Allocation requests retry on OOM until they succeed or the optional retry timeout is reached.
2. **State Machine (FSM)** - Manages the global lock state and enforces access rules that ensures consistency across multiple clients. See [State Machine](#state-machine) below for details.
2. **State Machine (FSM)** - Manages global lock state, waiter coordination, and disconnect cleanup.
3. **Metadata Store** - Key-value store for tensor metadata (shapes, dtypes, offsets), enabling clients to reconstruct model structure.
3. **Metadata Store / Layout State** - `GMS` owns the metadata table and committed layout hash. Allocations and metadata live in one flat store that is cleared on each new writer connect or writer abort.
### Client
Each GMS server is responsible for managing memory of only 1 GPU, and does not interact with GMS servers corresponding to other GPUs.
Clients connect to the server to acquire locks and access GPU memory. Two client classes are provided:
### Client
1. **GMSRPCClient** - Low-level RPC client for direct protocol access. Handles socket communication, msgpack serialization, and file descriptor passing via `SCM_RIGHTS`. The socket connection **is** the lock - connection lifetime equals lock lifetime, providing automatic crash resilience.
Clients connect to the server to acquire locks and access GPU memory. The supported client API is:
2. **GMSClientMemoryManager** - High-level client that wraps `GMSRPCClient` and handles all CUDA VMM operations for memory import and mapping safely:
1. **GMSClientMemoryManager** - High-level client that wraps an internal RPC transport layer and handles all CUDA VMM operations for memory import and mapping safely:
- Imports file descriptors and converts them to CUDA memory handles
- Reserves virtual address space and maps physical memory
- Sets appropriate access permissions (RW for writers, RO for readers)
- Supports **unmap/remap** for VA-stable memory release under memory pressure
> **Note**: Always use `GMSClientMemoryManager` to interact with GMS from client code. The low-level `GMSRPCClient` is an implementation detail and should not be used directly.
> **Note**: Always use `GMSClientMemoryManager` to interact with GMS from client code. The low-level RPC client is an implementation detail and should not be used directly.
### Memory Allocation and Import Flow
......@@ -92,7 +92,7 @@ The following diagram shows how `GMSClientMemoryManager` interacts with the serv
```mermaid
sequenceDiagram
participant C as GMSClientMemoryManager
participant S as GMS Server
participant S as GMS
participant GPU as GPU Memory
%% Connection
......@@ -111,7 +111,7 @@ sequenceDiagram
%% Export/Import (Both Writer and Reader)
Note over C,GPU: Both Writer and Reader: Export and map
C->>S: ExportRequest(allocation_id)
C->>S: ExportAllocationRequest(allocation_id)
S->>GPU: cuMemExportToShareableHandle(handle)
GPU-->>S: fd
S-->>C: Response + fd (via SCM_RIGHTS)
......@@ -152,31 +152,103 @@ stateDiagram-v2
| State | Description | Can Connect RW | Can Connect RO |
|-------|-------------|:--------------:|:--------------:|
| `EMPTY` | No connections, no committed weights | ✓ | ✗ |
| `EMPTY` | No connections, no committed layout visible | ✓ | ✗ |
| `RW` | Writer connected (exclusive access) | ✗ | ✗ |
| `COMMITTED` | Weights published, no active connections | ✓ | ✓ |
| `COMMITTED` | Committed layout visible to readers, no active connections | ✓ | ✓ |
| `RO` | One or more readers connected (shared access) | ✗ | ✓ |
### Events
| Event | Trigger | Description |
|-------|---------|-------------|
| `RW_CONNECT` | Writer connects | Acquires exclusive write lock |
| `RW_COMMIT` | Writer calls `commit()` | Publishes weights, releases lock |
| `RW_ABORT` | Writer disconnects without commit | Discards allocations, releases lock |
| `RW_CONNECT` | Writer connects | Acquires exclusive write lock, clears the previous committed layout immediately, and starts a fresh RW layout build |
| `RW_COMMIT` | Writer calls `commit()` | Publishes the current RW layout as the committed layout and releases the lock |
| `RW_ABORT` | Writer disconnects without commit | Drops the active RW layout and returns to `EMPTY` |
| `RO_CONNECT` | Reader connects | Acquires shared read lock |
| `RO_DISCONNECT` | Reader disconnects | Releases shared lock; if last reader, returns to COMMITTED |
### Lock Semantics
The socket connection **is** the lock:
A handshaken socket connection **is** the lock:
- **Crash resilience**: Connection close (including process crash) automatically releases the lock
- **No explicit unlock**: Eliminates forgotten locks and deadlocks
- **Atomic transitions**: State changes happen atomically with socket operations
The only exception is the runtime inspection probes (`GetRuntimeState`, `GetEventHistory`): they connect, fetch diagnostics, and close without entering the lock FSM.
### Layout Lifecycle
Layout creation and publication work like this:
```mermaid
flowchart LR
A[EMPTY or COMMITTED] -->|RW_CONNECT| B[Fresh RW layout]
B -->|Allocate memory and write metadata| C{Writer outcome}
C -->|RW_COMMIT| D[Publish layout as committed]
C -->|RW_ABORT| E[Discard layout]
D -->|Next RW_CONNECT| F[Fresh RW layout]
E -->|Next RW_CONNECT| F
```
- `RW_CONNECT` starts a fresh RW layout build.
- `RW_COMMIT` publishes the current layout; it does not create another one.
- `RW_ABORT` discards the current RW layout and returns the system to `EMPTY`.
- Allocations and metadata live in one flat store that is cleared on `RW_CONNECT` and `RW_ABORT`.
- RO requests are served only from the committed layout, while RW requests mutate only the active layout.
- Read RPCs (`export`, allocation lookup/listing, metadata lookup/listing) operate on that single live store. This is safe because the FSM prevents RW and RO sessions from coexisting.
- `metadata_put` validates allocation ownership and offset bounds, `free` cascades metadata cleanup, and `commit` rejects dangling metadata references.
### Allocation Backpressure on OOM
When a writer requests a new allocation, GMS treats CUDA OOM as a transient condition:
- `cuMemCreate` OOM does **not** immediately fail the request.
- The server retries in a loop and only returns success after allocation is created.
- Server CLI flags:
- `--alloc-retry-interval` (default `0.5`)
- `--alloc-retry-timeout` (default unset = wait indefinitely)
This ensures the "new writer gets fresh allocations" workflow can wait for memory reclamation instead of racing into immediate OOM failures.
### Guarantees
- GMS guarantees that its own RPCs do not mix committed and active generations, and that `GMSClientMemoryManager.commit()` performs a CUDA synchronize and unmaps the writer's local mappings before publish.
- After local unmap, `commit()` does not attempt in-process recovery. Non-CUDA failures raise, and CUDA VMM failures exit the process.
- The only non-fatal client connection failure is lock acquisition timeout. Other client-side GMS transport, protocol, and server error responses raise.
- Any non-OOM CUDA VMM failure on either client or server is fatal and exits the process.
- On the server, an untrusted client connection is isolated to that connection: transport loss and response-send failures unwind the connection state, and only server invariant violations or CUDA failures kill the server.
- Runtime-state `allocation_count` and `allocations_cleared` report server-owned allocation handles only. Imported handles in other processes can still keep VRAM alive after the server clears its own layout state.
- GMS *does not* prove that a disconnected or already-submitted writer has no in-flight GPU work left on the device. The mitigation in this design is that new RW layouts use fresh allocations and may wait for memory reclamation before allocation succeeds.
---
### Server Trust Boundary
```mermaid
flowchart TD
A[Client event on server connection] --> B{Can server read and decode it?}
B -- no --> C[Drop connection]
C --> D[Run disconnect cleanup]
D --> E[RW_ABORT or RO_DISCONNECT]
B -- yes --> F{Valid client request?}
F -- no --> G[Send ErrorResponse]
F -- yes --> H{Did request expose server invariant failure?}
H -- yes --> I[Exit server process]
H -- no --> J[Build response or apply commit]
J --> K{Can server send response?}
K -- no --> D
K -- yes --> L[Continue session or close committed writer]
```
- `Drop connection` means the server stops trusting that socket and unwinds only that connection's lock state.
- After `RW_COMMIT`, disconnect cleanup only closes the committed writer socket; it does not roll the server back to `RW_ABORT`.
- `Valid client request?` covers mode/state violations, unknown requests, and request validation failures like bad metadata offsets.
- `Did request expose server invariant failure?` covers impossible layout/FSM states and commit-time metadata integrity failures.
## Sequence Diagrams
### Writer Flow (Cold Start)
......@@ -187,11 +259,14 @@ The first worker loads weights from disk and publishes them to GMS.
sequenceDiagram
participant W as Writer Process
participant C as GMSClientMemoryManager
participant S as GMS Server
participant S as GMS
W->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
W->>C: mgr.connect(RW)
C->>S: HandshakeRequest(lock_type=RW)
S->>S: Session FSM: EMPTY/COMMITTED -> RW
S->>S: Clear prior committed layout
S->>S: Start fresh RW layout
S-->>C: HandshakeResponse(success=true)
loop For each tensor
......@@ -201,9 +276,14 @@ sequenceDiagram
end
W->>C: mgr.commit()
C->>GPU: synchronize()
C->>GPU: cuMemUnmap(...) + cuMemRelease(...)
C->>S: CommitRequest()
S->>S: Publish current layout as committed
S->>S: FSM: RW → COMMITTED
S-->>C: CommitResponse(success=true)
W->>C: mgr.connect(RO)
W->>C: mgr.remap_all_vas()
```
### Reader Flow (Warm Start)
......@@ -214,7 +294,7 @@ Subsequent workers import weights from GMS instead of loading from disk.
sequenceDiagram
participant R as Reader Process
participant C as GMSClientMemoryManager
participant S as GMS Server
participant S as GMS
R->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
R->>C: mgr.connect(RO)
......@@ -242,7 +322,7 @@ Readers can temporarily release GPU memory while preserving virtual address rese
sequenceDiagram
participant R as Reader Process
participant C as GMSClientMemoryManager
participant S as GMS Server
participant S as GMS
participant GPU as GPU Memory
Note over R,GPU: Need to temporarily release GPU memory
......@@ -256,33 +336,26 @@ sequenceDiagram
Note over C: Keep VA reservation!
end
R->>C: mgr.disconnect()
R->>C: mgr.abort()
C->>S: Close socket (release RO lock)
S->>S: FSM: RO → COMMITTED (if last reader)
Note over R,GPU: GPU memory released, VA preserved
Note over R,GPU: Another writer could modify weights here
Note over R,GPU: Another writer could publish a new layout here
R->>C: mgr.connect(RO)
C->>S: HandshakeRequest(lock_type=RO)
S->>S: FSM: COMMITTED → RO
S-->>C: HandshakeResponse(success=true)
R->>C: mgr.remap_all_vas()
C->>S: GetStateHashRequest()
S-->>C: GetStateHashResponse(hash)
alt hash == saved_hash
loop For each preserved VA
C->>S: ExportRequest(allocation_id)
S-->>C: Response + fd
C->>GPU: cuMemImportFromShareableHandle(fd)
C->>GPU: cuMemMap(same_va, handle)
Note over C: Tensors valid at same addresses!
end
C->>S: Export preserved allocations from the committed layout
S-->>C: Response + FDs
C->>GPU: Import handles and remap at preserved VAs
C-->>R: Remap succeeds and tensor pointers stay valid
else hash != saved_hash
C-->>R: StaleMemoryLayoutError
Note over R: Must re-import from scratch
C-->>R: Re-import from scratch
end
```
......@@ -294,9 +367,9 @@ The `RW_OR_RO` mode automatically selects writer or reader based on server state
sequenceDiagram
participant P as Process
participant C as GMSClientMemoryManager
participant S as GMS Server
participant S as GMS
Note over P,S: Auto-mode: Writer if first, Reader if weights exist
Note over P,S: Auto-mode: try RW only when no committed layout exists
P->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
P->>C: mgr.connect(RW_OR_RO)
......@@ -313,10 +386,10 @@ sequenceDiagram
S-->>C: HandshakeResponse(granted=RO, committed=true)
Note over P: Subsequent process - import from GMS
else RW held by another
S->>S: Wait for RO availability
S->>S: FSM: COMMITTED → RO
S->>S: Wait until a committed layout becomes available
S->>S: Then grant RO from COMMITTED
S-->>C: HandshakeResponse(granted=RO, committed=true)
Note over P: Wait for writer to finish
Note over P: Wait for writer to publish committed weights
end
```
......@@ -355,13 +428,15 @@ During `remap_all_vas()`:
### 4. Memory Layout Hash
On commit, the server computes a hash of:
- All allocation IDs, sizes, and tags
- All metadata entries
- All allocation layout slots, sizes, aligned sizes, and tags
- All metadata keys, offsets, and values
On `remap_all_vas()`, this hash is checked:
- If match: Safe to remap (layout unchanged)
- If mismatch: Raise `StaleMemoryLayoutError` (must re-import)
The hash is tied to the currently committed layout and is cleared as soon as a writer acquires RW.
**Important**: This detects **structural** changes, not **content** changes.
Weight values can be modified in-place (e.g., RL training updates) as long as the structure is preserved.
......@@ -411,17 +486,16 @@ class GMSClientMemoryManager:
# --- Tier 1: Connection ---
def connect(lock_type: RequestedLockType, timeout_ms: Optional[int] = None) -> None
def disconnect() -> None
def abort() -> None
# --- Tier 1: Handle ops (server-side, RW only) ---
def allocate_handle(size: int, tag: str = "default") -> str # Returns allocation_id
def allocate_handle(size: int, tag: str = "default") -> Tuple[str, int] # Returns allocation_id, layout_slot
def export_handle(allocation_id: str) -> int # Returns FD
def get_handle_info(allocation_id: str) -> AllocationInfo
def get_handle_info(allocation_id: str) -> GetAllocationResponse
def free_handle(allocation_id: str) -> bool
def clear_all_handles() -> int # Returns count cleared
def commit() -> bool # Transition to COMMITTED
def commit() -> bool # Sync + unmap local mappings + publish; raises on non-CUDA failure after unmap
def get_memory_layout_hash() -> str
def list_handles(tag: Optional[str] = None) -> List[Dict]
def list_handles(tag: Optional[str] = None) -> List[GetAllocationResponse]
# --- Tier 1: VA ops (local) ---
def reserve_va(size: int) -> int # Returns VA
......@@ -430,7 +504,7 @@ class GMSClientMemoryManager:
def free_va(va: int) -> None # Releases VA reservation
# --- Tier 1: Metadata ---
def metadata_put(key: str, allocation_id: str, offset: int, value: bytes) -> bool
def metadata_put(key: str, allocation_id: str, offset_bytes: int, value: bytes) -> bool
def metadata_get(key: str) -> Optional[Tuple[str, int, bytes]]
def metadata_list(prefix: str = "") -> List[str]
def metadata_delete(key: str) -> bool
......@@ -441,15 +515,9 @@ class GMSClientMemoryManager:
def unmap_all_vas() -> None # Sync + unmap all, preserve VA reservations
def remap_all_vas() -> None # Re-import at preserved VAs (checks layout hash)
def reallocate_all_handles(tag="default") -> None # Fresh server handles for preserved VAs
def close(free: bool = False) -> None
def close() -> None
```
## Limitations
1. **Single-GPU per server**: Each GMS server manages one GPU device
2. **CUDA VMM required**: Requires a GPU with Virtual Memory Management support. Check at runtime via `CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED` - there is no guaranteed minimum compute capability
3. **No content validation**: Remap doesn't detect in-place weight modifications
---
## Framework Integration (vLLM / SGLang)
......@@ -461,8 +529,10 @@ GMS provides pre-built integrations for vLLM and SGLang. Enable GMS by passing `
When `--load-format gms` is set:
1. **A GMS server must already be running** for the target GPU device. The engine connects to it via a Unix socket derived from the GPU UUID.
2. The engine uses `RW_OR_RO` mode by default: the **first** process gets RW (loads weights from disk, commits to GMS), and **subsequent** processes get RO (import weights from GMS metadata).
3. Weights are managed by GMS; KV cache is managed by the framework's own allocator (e.g., vLLM's `CuMemAllocator`).
2. The engine uses `RW_OR_RO` mode by default: if no committed layout exists and no writer holds the lock, the first process gets RW and loads weights from disk. Otherwise clients wait for a committed layout and then get RO to import published weights.
3. Both weights and KV cache are managed by GMS, but they use separate tags:
- `weights`: publish/import flow (`RW_OR_RO`, then `RO` after commit)
- `kv_cache`: separate RW-only tag for mutable KV-cache memory
#### vLLM
......@@ -470,6 +540,7 @@ When `--load-format gms` is set:
python -m dynamo.vllm \
--model <model> \
--load-format gms \
--worker-cls gpu_memory_service.integrations.vllm.worker:GMSWorker \
--enable-sleep-mode \
--gpu-memory-utilization 0.9
```
......@@ -478,7 +549,10 @@ The integration uses a custom worker class (`GMSWorker`) that:
- Establishes the GMS connection early in `init_device()` so vLLM's `MemorySnapshot` can account for committed weights
- Registers a custom model loader (`GMSModelLoader`) for the `gms` load format
- Patches `torch.cuda.empty_cache` to avoid releasing GMS-managed memory
- Routes weight allocation through a `CUDAPluggableAllocator` backed by GMS
- Uses two GMS tags on the GPU:
- `weights`: normal publish/import flow (`RW_OR_RO`, then `RO` after commit)
- `kv_cache`: separate RW-only tag for mutable KV-cache memory
- Routes both weight and KV-cache allocation through a `CUDAPluggableAllocator` backed by the appropriate GMS tag
#### SGLang
......@@ -490,9 +564,10 @@ python -m dynamo.sglang \
--mem-fraction-static 0.9
```
The integration patches `torch_memory_saver` to route weight operations through GMS:
- Weights (`"weights"` / `"model_weights"` tags) go through `GMSMemorySaverImpl`
- Other tags (e.g., `"kv_cache"`) are delegated to the default torch mempool implementation
The integration patches `torch_memory_saver` to route both weight and KV-cache operations through GMS:
- Weights (`"weights"` / `"model_weights"` tags) use the `weights` GMS tag
- KV cache (`"kv_cache"`) uses a separate RW-only `kv_cache` GMS tag
- Other tags still use the default torch mempool implementation
- The `--enable-memory-saver` flag is required to activate the memory saver pathway
### Shadow Engine Failover (Sleep / Wake)
......@@ -502,9 +577,14 @@ Both integrations support releasing and reclaiming GPU memory for shadow engine
- **vLLM**: `sleep` / `wake_up` (via `/engine/sleep` and `/engine/wake_up` HTTP endpoints)
- **SGLang**: `release_memory_occupation` / `resume_memory_occupation` (via the corresponding HTTP endpoints)
Under the hood, sleeping calls `unmap_all_vas()` + `disconnect()` to release GPU memory while preserving VA reservations, and waking calls `connect(RO)` + `remap_all_vas()` to re-import weights at the same virtual addresses. Tensor pointers remain valid, so no model re-initialization is needed.
Under the hood, sleeping calls `unmap_all_vas()` + `abort()` to release GPU memory while preserving VA reservations. Waking is tag-specific:
- **weights**: `connect(RO)` + `remap_all_vas()`
- **kv_cache**: `connect(RW)` + `reallocate_all_handles("kv_cache")` + `remap_all_vas()`
Tensor pointers remain valid because the original virtual addresses are preserved.
This enables a shadow engine to release its GPU memory, let a primary engine use the GPU, and then reclaim the memory after the primary is killed.
This enables a shadow engine to release its GPU memory, let a primary engine use the GPU, and then reclaim the memory after the primary is killed. The mutable KV cache always moves through a fresh RW layout in its own GMS tag before it is reallocated.
### Configuration via `model_loader_extra_config`
......
......@@ -32,7 +32,9 @@ from gpu_memory_service.client.memory_manager import (
# PyTorch integration (GMS client memory manager)
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
__all__ = [
......@@ -42,4 +44,6 @@ __all__ = [
# GMS client memory manager
"get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
]
......@@ -6,6 +6,7 @@
import argparse
import logging
from dataclasses import dataclass
from typing import Optional
from gpu_memory_service.common.utils import get_socket_path
......@@ -17,7 +18,10 @@ class Config:
"""Configuration for GPU Memory Service server."""
device: int
tag: str
socket_path: str
alloc_retry_interval: float
alloc_retry_timeout: Optional[float]
verbose: bool
......@@ -33,6 +37,12 @@ def parse_args() -> Config:
required=True,
help="CUDA device ID to manage memory for.",
)
parser.add_argument(
"--tag",
type=str,
default="weights",
help="Logical GMS tag for this server (default: weights).",
)
parser.add_argument(
"--socket-path",
type=str,
......@@ -45,14 +55,33 @@ def parse_args() -> Config:
action="store_true",
help="Enable verbose logging.",
)
parser.add_argument(
"--alloc-retry-interval",
type=float,
default=0.5,
help="Seconds to sleep between allocation retries on CUDA OOM (default: 0.5).",
)
parser.add_argument(
"--alloc-retry-timeout",
type=float,
default=None,
help="Optional max seconds to wait for allocation retries before failing (default: wait indefinitely).",
)
args = parser.parse_args()
# Use UUID-based socket path by default (stable across CUDA_VISIBLE_DEVICES)
socket_path = args.socket_path or get_socket_path(args.device)
socket_path = args.socket_path or get_socket_path(args.device, args.tag)
if args.alloc_retry_interval <= 0:
parser.error("--alloc-retry-interval must be > 0")
if args.alloc_retry_timeout is not None and args.alloc_retry_timeout <= 0:
parser.error("--alloc-retry-timeout must be > 0 when set")
return Config(
device=args.device,
tag=args.tag,
socket_path=socket_path,
alloc_retry_interval=args.alloc_retry_interval,
alloc_retry_timeout=args.alloc_retry_timeout,
verbose=args.verbose,
)
......@@ -13,7 +13,6 @@ Usage:
import asyncio
import logging
import signal
import uvloop
from gpu_memory_service.server import GMSRPCServer
......@@ -37,33 +36,28 @@ async def worker() -> None:
logging.getLogger("gpu_memory_service").setLevel(logging.DEBUG)
logger.info(f"Starting GPU Memory Service Server for device {config.device}")
logger.info("GMS tag: %s", config.tag)
logger.info(f"Socket path: {config.socket_path}")
server = GMSRPCServer(config.socket_path, device=config.device)
# Set up shutdown handling
shutdown_event = asyncio.Event()
def signal_handler():
logger.info("Received shutdown signal")
shutdown_event.set()
loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
await server.start()
logger.info(
"Allocation retry config: interval=%ss timeout=%s",
config.alloc_retry_interval,
(
f"{config.alloc_retry_timeout}s"
if config.alloc_retry_timeout is not None
else "none"
),
)
server = GMSRPCServer(
config.socket_path,
device=config.device,
allocation_retry_interval=config.alloc_retry_interval,
allocation_retry_timeout=config.alloc_retry_timeout,
)
logger.info("GPU Memory Service Server ready, waiting for connections...")
logger.info(f"Clients can connect via socket: {config.socket_path}")
# Wait for shutdown signal
try:
await shutdown_event.wait()
finally:
logger.info("Shutting down GPU Memory Service Server...")
await server.stop()
logger.info("GPU Memory Service Server shutdown complete")
await server.serve()
def main() -> None:
......
......@@ -7,7 +7,6 @@ This module provides the client-side components for interacting with the
GPU Memory Service:
- GMSClientMemoryManager: Manages local VA mappings of remote GPU memory
- GMSRPCClient: Low-level RPC client (pure Python, no PyTorch dependency)
For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch.
"""
......@@ -16,10 +15,8 @@ from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
StaleMemoryLayoutError,
)
from gpu_memory_service.client.rpc import GMSRPCClient
__all__ = [
"GMSClientMemoryManager",
"StaleMemoryLayoutError",
"GMSRPCClient",
]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Client-side CUDA VMM utilities.
These functions wrap CUDA driver API calls used by the client memory manager
for importing, mapping, and unmapping GPU memory.
"""
from __future__ import annotations
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import check_cuda_result
from gpu_memory_service.common.types import GrantedLockType
def import_handle_from_fd(fd: int) -> int:
"""Import a CUDA memory handle from a file descriptor.
Closes the FD after import — the imported handle holds its own reference
to the physical allocation. Leaving the FD open leaks a DMA-buf ref that
prevents cuMemRelease from freeing GPU memory.
Args:
fd: POSIX file descriptor received via SCM_RIGHTS.
Returns:
CUDA memory handle.
"""
try:
result, handle = cuda.cuMemImportFromShareableHandle(
fd,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
)
check_cuda_result(result, "cuMemImportFromShareableHandle")
return int(handle)
finally:
os.close(fd)
def reserve_va(size: int, granularity: int) -> int:
"""Reserve virtual address space.
Args:
size: Size in bytes (should be aligned to granularity).
granularity: VMM allocation granularity.
Returns:
Reserved virtual address.
"""
result, va = cuda.cuMemAddressReserve(size, granularity, 0, 0)
check_cuda_result(result, "cuMemAddressReserve")
return int(va)
def free_va(va: int, size: int) -> None:
"""Free a virtual address reservation.
Args:
va: Virtual address to free.
size: Size of the reservation.
"""
(result,) = cuda.cuMemAddressFree(va, size)
check_cuda_result(result, "cuMemAddressFree")
def map_to_va(va: int, size: int, handle: int) -> None:
"""Map a CUDA handle to a virtual address.
Args:
va: Virtual address (must be reserved).
size: Size of the mapping.
handle: CUDA memory handle.
"""
(result,) = cuda.cuMemMap(va, size, 0, handle, 0)
check_cuda_result(result, "cuMemMap")
def set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
"""Set access permissions for a mapped region.
Args:
va: Virtual address.
size: Size of the region.
device: CUDA device index.
access: Access mode - RO for read-only, RW for read-write.
"""
acc = cuda.CUmemAccessDesc()
acc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
acc.location.id = device
acc.flags = (
cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
if access == GrantedLockType.RO
else cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
)
(result,) = cuda.cuMemSetAccess(va, size, [acc], 1)
check_cuda_result(result, "cuMemSetAccess")
def unmap(va: int, size: int) -> None:
"""Unmap a virtual address region.
Args:
va: Virtual address to unmap.
size: Size of the mapping.
"""
(result,) = cuda.cuMemUnmap(va, size)
check_cuda_result(result, "cuMemUnmap")
def release_handle(handle: int) -> None:
"""Release a CUDA memory handle.
Args:
handle: CUDA memory handle to release.
"""
(result,) = cuda.cuMemRelease(handle)
check_cuda_result(result, "cuMemRelease")
def validate_pointer(va: int) -> bool:
"""Validate that a mapped VA is accessible.
Returns True if the pointer is valid, False otherwise (logs a warning).
"""
result, _dev_ptr = cuda.cuPointerGetAttribute(
cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va
)
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
err_msg = ""
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
import logging
logging.getLogger(__name__).warning(
"cuPointerGetAttribute failed for VA 0x%x: %s (%s)",
va,
result,
err_msg,
)
return False
return True
def synchronize() -> None:
"""Synchronize the current CUDA context.
Blocks until all preceding commands in the current context have completed.
"""
(result,) = cuda.cuCtxSynchronize()
check_cuda_result(result, "cuCtxSynchronize")
def set_current_device(device: int) -> None:
"""Set the current CUDA device by activating its primary context.
Args:
device: CUDA device index.
"""
result, ctx = cuda.cuDevicePrimaryCtxRetain(device)
check_cuda_result(result, "cuDevicePrimaryCtxRetain")
(result,) = cuda.cuCtxSetCurrent(ctx)
check_cuda_result(result, "cuCtxSetCurrent")
......@@ -8,7 +8,7 @@ Two-tier API for GPU memory lifecycle management:
Tier 1 (Atomic Operations):
- Connection: connect(), disconnect()
- Handle ops (server-side cuMem allocations): allocate_handle, export_handle,
get_handle_info, free_handle, clear_all_handles, commit, list_handles,
get_handle_info, free_handle, commit, list_handles,
get_memory_layout_hash
- VA ops (local address space): reserve_va, map_va, unmap_va, free_va
- Metadata: metadata_put, metadata_get, metadata_list, metadata_delete
......@@ -34,25 +34,22 @@ import logging
from dataclasses import dataclass
from typing import Dict, List, Optional
from gpu_memory_service.client.cuda_vmm_utils import free_va as _cuda_free_va
from gpu_memory_service.client.cuda_vmm_utils import (
import_handle_from_fd,
map_to_va,
release_handle,
)
from gpu_memory_service.client.cuda_vmm_utils import reserve_va as _cuda_reserve_va
from gpu_memory_service.client.cuda_vmm_utils import (
set_access,
set_current_device,
synchronize,
unmap,
validate_pointer,
)
from gpu_memory_service.client.rpc import GMSRPCClient
from gpu_memory_service.common.cuda_vmm_utils import (
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.cuda_utils import (
align_to_granularity,
get_allocation_granularity,
cuda_set_current_device,
cuda_synchronize,
cuda_validate_pointer,
cumem_address_free,
cumem_address_reserve,
cumem_get_allocation_granularity,
cumem_import_from_shareable_handle_close_fd,
cumem_map,
cumem_release,
cumem_set_access,
cumem_unmap,
)
from gpu_memory_service.common.protocol.messages import GetAllocationResponse
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__)
......@@ -97,6 +94,7 @@ class LocalMapping:
aligned_size: int
handle: int # 0 if unmapped but VA reserved
tag: str
layout_slot: int
def with_handle(self, handle: int) -> "LocalMapping":
return LocalMapping(
......@@ -106,9 +104,14 @@ class LocalMapping:
self.aligned_size,
handle,
self.tag,
self.layout_slot,
)
def with_allocation_id(self, allocation_id: str) -> "LocalMapping":
def with_server_identity(
self,
allocation_id: str,
layout_slot: int,
) -> "LocalMapping":
return LocalMapping(
allocation_id,
self.va,
......@@ -116,6 +119,7 @@ class LocalMapping:
self.aligned_size,
self.handle,
self.tag,
layout_slot,
)
......@@ -134,7 +138,7 @@ class GMSClientMemoryManager:
self.socket_path = socket_path
self.device = device
self._client: Optional[GMSRPCClient] = None
self._client: Optional[_GMSClientSession] = None
self._mappings: Dict[int, LocalMapping] = {} # va -> mapping
self._inverse_mapping: Dict[str, int] = {}
......@@ -145,8 +149,8 @@ class GMSClientMemoryManager:
self._va_preserved = False
self._last_memory_layout_hash: str = ""
set_current_device(self.device)
self.granularity = get_allocation_granularity(device)
cuda_set_current_device(self.device)
self.granularity = cumem_get_allocation_granularity(device)
# ==================== Properties ====================
......@@ -180,40 +184,59 @@ class GMSClientMemoryManager:
Updates self._granted_lock_type based on granted lock type. Saves memory layout hash
for stale detection if server is in committed state.
"""
self._client = GMSRPCClient(
if self._client is not None:
raise RuntimeError("Memory manager is already connected")
self._client = _GMSClientSession(
self.socket_path,
lock_type=lock_type,
timeout_ms=timeout_ms,
)
self._granted_lock_type = self._client.lock_type
# Save layout hash for stale detection on future remap
if self._client.committed:
if self._granted_lock_type == GrantedLockType.RW:
self._last_memory_layout_hash = ""
return
# Preserve the pre-unmap hash across reconnects so remap_all_vas can
# detect that another writer changed the committed layout while this
# process was disconnected.
if self._client.committed and (
not self._va_preserved or not self._last_memory_layout_hash
):
self._last_memory_layout_hash = self._client.get_memory_layout_hash()
elif not self._va_preserved:
self._last_memory_layout_hash = ""
def abort(self) -> None:
"""Drop the GMS session.
def disconnect(self) -> None:
"""Close connection and release lock."""
Clean callers should unmap first. This also supports abrupt session
drop with live mappings still present.
"""
if self._client is not None:
try:
self._client.close()
except Exception:
pass
self._client = None
finally:
self._client = None
self._granted_lock_type = None
return
self._granted_lock_type = None
# ==================== Tier 1: Handle Operations (server-side) ====================
def allocate_handle(self, size: int, tag: str = "default") -> str:
def allocate_handle(self, size: int, tag: str = "default") -> tuple[str, int]:
"""Allocate a cuMem handle on the server.
Returns allocation_id. Size is aligned to VMM granularity before sending.
Returns allocation_id and layout_slot. Size is aligned to VMM granularity
before sending.
"""
self._require_rw()
aligned_size = align_to_granularity(size, self.granularity)
allocation_id, server_aligned = self._client_rpc.allocate(aligned_size, tag)
if int(server_aligned) != aligned_size:
response = self._client_rpc.allocate_info(aligned_size, tag)
if int(response.aligned_size) != aligned_size:
raise RuntimeError(
f"Alignment mismatch: {aligned_size} vs {server_aligned}"
"GMS allocation alignment mismatch: "
f"{aligned_size} vs {response.aligned_size}"
)
return allocation_id
return response.allocation_id, int(response.layout_slot)
def export_handle(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD."""
......@@ -225,34 +248,43 @@ class GMSClientMemoryManager:
def free_handle(self, allocation_id: str) -> bool:
"""Release a cuMem allocation on the server."""
return self._client_rpc.free(allocation_id)
ok = self._client_rpc.free(allocation_id)
if not ok:
raise RuntimeError(
f"GMS free_handle failed for allocation_id={allocation_id}"
)
return True
def clear_all_handles(self) -> int:
"""Clear all allocations on the server. NO local unmap.
def commit(self) -> bool:
"""Synchronize, unmap writer mappings, then commit.
Safe at startup (no local mappings) and during failover
(preserves local VA reservations).
Commit is a publish barrier. It guarantees all prior GPU writes in the
current context are complete before the server transitions state. After
a successful commit, the former writer process no longer has any mapped
access to the published allocations. Any failure after local unmap
raises because the process cannot safely recover its CUDA VMM state.
"""
self._require_rw()
return self._client_rpc.clear_all()
def commit(self) -> bool:
"""Server-only commit: transition to COMMITTED state.
# Publish barrier: all writer-side GPU work must be visible before commit.
cuda_synchronize()
No synchronize(), no CUDA access flip. The caller is responsible for
synchronizing before calling this. Server closes the RW socket on
success, so self._client becomes None.
"""
self._require_rw()
ok = self._client_rpc.commit()
if ok:
self._client = None
return bool(ok)
for mapping in list(self._mappings.values()):
if mapping.handle != 0:
self.unmap_va(mapping.va)
self._va_preserved = True
self._unmapped = True
self._client_rpc.commit()
self._client = None
self._granted_lock_type = None
return True
def get_memory_layout_hash(self) -> str:
return self._client_rpc.get_memory_layout_hash()
def list_handles(self, tag: Optional[str] = None) -> List[Dict]:
def list_handles(self, tag: Optional[str] = None) -> List[GetAllocationResponse]:
return self._client_rpc.list_allocations(tag)
# ==================== Tier 1: Metadata ====================
......@@ -276,26 +308,26 @@ class GMSClientMemoryManager:
def reserve_va(self, size: int) -> int:
"""Reserve virtual address space (cuMemAddressReserve). No tracking."""
aligned_size = align_to_granularity(size, self.granularity)
return _cuda_reserve_va(aligned_size, self.granularity)
return cumem_address_reserve(aligned_size, self.granularity)
def map_va(self, fd: int, va: int, size: int, allocation_id: str, tag: str) -> int:
def map_va(
self,
fd: int,
va: int,
size: int,
allocation_id: str,
tag: str,
layout_slot: int,
) -> int:
"""Import FD + cuMemMap + set access + track.
Access is set based on current lock_type. Returns the CUDA handle.
"""
assert self._granted_lock_type is not None
aligned_size = align_to_granularity(size, self.granularity)
handle = import_handle_from_fd(fd)
try:
map_to_va(va, aligned_size, handle)
set_access(va, aligned_size, self.device, self._granted_lock_type)
except Exception:
try:
unmap(va, aligned_size)
except Exception:
pass
release_handle(handle)
raise
handle = cumem_import_from_shareable_handle_close_fd(fd)
cumem_map(va, aligned_size, handle)
cumem_set_access(va, aligned_size, self.device, self._granted_lock_type)
self._track_mapping(
LocalMapping(
allocation_id=allocation_id,
......@@ -304,6 +336,7 @@ class GMSClientMemoryManager:
aligned_size=aligned_size,
handle=handle,
tag=tag,
layout_slot=layout_slot,
)
)
return handle
......@@ -317,8 +350,8 @@ class GMSClientMemoryManager:
mapping = self._mappings.get(va)
if mapping is None or mapping.handle == 0:
return
unmap(va, mapping.aligned_size)
release_handle(mapping.handle)
cumem_unmap(va, mapping.aligned_size)
cumem_release(mapping.handle)
self._mappings[va] = mapping.with_handle(0)
def free_va(self, va: int) -> None:
......@@ -334,7 +367,7 @@ class GMSClientMemoryManager:
mapping = self._mappings.get(va)
if mapping is None:
return
_cuda_free_va(va, mapping.aligned_size)
cumem_address_free(va, mapping.aligned_size)
self._mappings.pop(va, None)
self._inverse_mapping.pop(mapping.allocation_id, None)
......@@ -370,28 +403,21 @@ class GMSClientMemoryManager:
alloc_size = int(info.size)
aligned_size = int(info.aligned_size)
alloc_tag = str(getattr(info, "tag", "default"))
layout_slot = int(info.layout_slot)
fd = self.export_handle(allocation_id)
va = self.reserve_va(aligned_size)
try:
self.map_va(fd, va, alloc_size, allocation_id, alloc_tag)
except Exception:
_cuda_free_va(va, align_to_granularity(aligned_size, self.granularity))
raise
self.map_va(fd, va, alloc_size, allocation_id, alloc_tag, layout_slot)
return va
# Allocate path
if size <= 0:
raise ValueError("size must be > 0 when allocation_id is None")
alloc_id = self.allocate_handle(size, tag)
alloc_id, layout_slot = self.allocate_handle(size, tag)
fd = self.export_handle(alloc_id)
aligned_size = align_to_granularity(size, self.granularity)
va = self.reserve_va(aligned_size)
try:
self.map_va(fd, va, size, alloc_id, tag)
except Exception:
_cuda_free_va(va, aligned_size)
raise
self.map_va(fd, va, size, alloc_id, tag, layout_slot)
return va
def destroy_mapping(self, va: int) -> None:
......@@ -402,38 +428,25 @@ class GMSClientMemoryManager:
alloc_id = mapping.allocation_id
try:
self.unmap_va(va)
except Exception as e:
logger.warning("Error in unmap_va for 0x%x: %s", va, e)
try:
self.free_va(va)
except Exception as e:
logger.warning("Error in free_va for 0x%x: %s", va, e)
# Only free server handle if we're RW and haven't committed
if self._granted_lock_type == GrantedLockType.RW:
try:
self.free_handle(alloc_id)
except Exception:
pass
self.free_handle(alloc_id)
self.unmap_va(va)
self.free_va(va)
def unmap_all_vas(self) -> None:
"""Synchronize + unmap all VAs. Preserves VA reservations for remap."""
synchronize()
cuda_synchronize()
unmapped_count = 0
total_bytes = 0
for va, mapping in list(self._mappings.items()):
if mapping.handle == 0:
continue
try:
self.unmap_va(va)
unmapped_count += 1
total_bytes += mapping.aligned_size
except Exception as e:
logger.warning("Error unmapping VA 0x%x: %s", va, e)
self.unmap_va(va)
unmapped_count += 1
total_bytes += mapping.aligned_size
self._va_preserved = True
self._unmapped = True
......@@ -451,7 +464,7 @@ class GMSClientMemoryManager:
Checks layout hash for staleness. Validates each allocation still
exists and size matches before remapping.
"""
set_current_device(self.device)
cuda_set_current_device(self.device)
# Stale layout check
current_hash = self.get_memory_layout_hash()
......@@ -465,36 +478,50 @@ class GMSClientMemoryManager:
assert self._granted_lock_type is not None
allocations_by_slot = {
int(info.layout_slot): info for info in self.list_handles()
}
remapped_count = 0
total_bytes = 0
for va, mapping in list(self._mappings.items()):
for va, mapping in sorted(
self._mappings.items(), key=lambda item: item[1].layout_slot
):
if mapping.handle != 0:
continue # Already mapped
continue
# Validate allocation still exists
try:
alloc_info = self.get_handle_info(mapping.allocation_id)
except Exception as e:
alloc_info = allocations_by_slot.get(mapping.layout_slot)
if alloc_info is None:
raise StaleMemoryLayoutError(
f"Allocation {mapping.allocation_id} no longer exists: {e}"
) from e
f"Layout slot {mapping.layout_slot} is missing from the committed layout"
)
if int(alloc_info.aligned_size) != mapping.aligned_size:
raise StaleMemoryLayoutError(
f"Allocation {mapping.allocation_id} size changed: "
f"Layout slot {mapping.layout_slot} size changed: "
f"{mapping.aligned_size} vs {int(alloc_info.aligned_size)}"
)
if str(alloc_info.tag) != mapping.tag:
raise StaleMemoryLayoutError(
f"Layout slot {mapping.layout_slot} tag changed: "
f"{mapping.tag} vs {alloc_info.tag}"
)
# Re-import and map to preserved VA
fd = self.export_handle(mapping.allocation_id)
handle = import_handle_from_fd(fd)
map_to_va(va, mapping.aligned_size, handle)
set_access(va, mapping.aligned_size, self.device, self._granted_lock_type)
synchronize()
validate_pointer(va)
self._mappings[va] = mapping.with_handle(handle)
fd = self.export_handle(alloc_info.allocation_id)
handle = cumem_import_from_shareable_handle_close_fd(fd)
cumem_map(va, mapping.aligned_size, handle)
cumem_set_access(
va, mapping.aligned_size, self.device, self._granted_lock_type
)
cuda_synchronize()
cuda_validate_pointer(va)
if mapping.allocation_id != alloc_info.allocation_id:
self._inverse_mapping.pop(mapping.allocation_id, None)
self._mappings[va] = mapping.with_server_identity(
alloc_info.allocation_id,
int(alloc_info.layout_slot),
).with_handle(handle)
self._inverse_mapping[alloc_info.allocation_id] = va
remapped_count += 1
total_bytes += mapping.aligned_size
......@@ -523,24 +550,26 @@ class GMSClientMemoryManager:
)
reallocated = 0
for va, mapping in list(self._mappings.items()):
for va, mapping in sorted(
self._mappings.items(), key=lambda item: item[1].layout_slot
):
if mapping.handle != 0:
continue
# Allocate fresh handle on server (uses raw RPC to avoid re-aligning)
allocation_id, server_aligned = self._client_rpc.allocate(
mapping.aligned_size, tag
)
if int(server_aligned) != mapping.aligned_size:
response = self._client_rpc.allocate_info(mapping.aligned_size, tag)
if int(response.aligned_size) != mapping.aligned_size:
raise RuntimeError(
f"Alignment mismatch during reallocation: "
f"{mapping.aligned_size} vs {server_aligned}"
"GMS reallocation alignment mismatch: "
f"{mapping.aligned_size} vs {response.aligned_size}"
)
allocation_id = response.allocation_id
# Update tracking: new allocation_id, handle stays 0
old_alloc_id = mapping.allocation_id
self._inverse_mapping.pop(old_alloc_id, None)
self._mappings[va] = mapping.with_allocation_id(allocation_id)
self._mappings[va] = mapping.with_server_identity(
allocation_id,
int(response.layout_slot),
)
self._inverse_mapping[allocation_id] = va
reallocated += 1
......@@ -551,43 +580,25 @@ class GMSClientMemoryManager:
# ==================== Lifecycle ====================
def close(self, free: bool = False) -> None:
"""Best-effort cleanup. NOT reliable in crash/signal paths.
def close(self) -> None:
"""Strict cleanup.
synchronize + unmap all + free all VAs + disconnect.
free=True: also clear_all_handles() on server before disconnect.
VAs are freed by CUDA context teardown on process exit anyway.
synchronize + unmap all + free all VAs + abort.
"""
try:
synchronize()
except Exception:
pass
for va in list(self._mappings.keys()):
try:
self.unmap_va(va)
except Exception as e:
logger.warning("Error unmapping VA 0x%x during close: %s", va, e)
cuda_synchronize()
for va in list(self._mappings.keys()):
try:
self.free_va(va)
except Exception as e:
logger.warning("Error freeing VA 0x%x during close: %s", va, e)
if (
free
and self._client is not None
and self._granted_lock_type == GrantedLockType.RW
):
try:
self.clear_all_handles()
except Exception as e:
logger.warning("Error clearing handles during close: %s", e)
self.unmap_va(va)
self.free_va(va)
self.disconnect()
self.abort()
self._unmapped = False
self._va_preserved = False
from gpu_memory_service.client.torch.allocator import (
evict_gms_client_memory_manager,
)
evict_gms_client_memory_manager(self)
def __enter__(self) -> "GMSClientMemoryManager":
return self
......@@ -598,7 +609,7 @@ class GMSClientMemoryManager:
# ==================== Internals ====================
@property
def _client_rpc(self) -> GMSRPCClient:
def _client_rpc(self) -> _GMSClientSession:
"""Get connected client or raise."""
if self._client is None:
if self._unmapped:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service RPC Client.
"""Internal GPU Memory Service transport.
Low-level RPC client stub. The client provides a simple interface for acquiring
locks and performing allocation operations. The socket connection IS the lock.
This module has NO PyTorch dependency.
Usage:
# Writer (acquires RW lock in constructor)
with GMSRPCClient(socket_path, lock_type=RequestedLockType.RW) as client:
alloc_id, aligned_size = client.allocate(size=1024*1024)
fd = client.export(alloc_id)
# ... write weights using fd ...
client.commit()
# Lock released on exit
# Reader (acquires RO lock in constructor)
client = GMSRPCClient(socket_path, lock_type=RequestedLockType.RO)
if client.committed: # Check if weights are valid
allocations = client.list_allocations()
for alloc in allocations:
fd = client.export(alloc["allocation_id"])
# ... import and map fd ...
# Keep connection open during inference!
# client.close() only when done with inference
This module only owns Unix socket transport and typed request/response exchange.
Session semantics live in `gpu_memory_service.client.session`.
"""
from __future__ import annotations
import logging
import os
import socket
from typing import Dict, List, Optional, Tuple, Type, TypeVar
from typing import Optional, Tuple, Type, TypeVar
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
ClearAllRequest,
ClearAllResponse,
CommitRequest,
CommitResponse,
ErrorResponse,
ExportRequest,
FreeRequest,
FreeResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetLockStateRequest,
GetLockStateResponse,
GetStateHashRequest,
GetStateHashResponse,
HandshakeRequest,
HandshakeResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync
from gpu_memory_service.common.types import (
RW_REQUIRED,
GrantedLockType,
RequestedLockType,
)
from gpu_memory_service.common.types import RequestedLockType
T = TypeVar("T")
logger = logging.getLogger(__name__)
class GMSRPCClient:
"""GPU Memory Service RPC Client.
CRITICAL: Socket connection IS the lock.
- Constructor blocks until lock is acquired
- close() releases the lock
- committed property tells readers if weights are valid
class _GMSRPCTransport:
"""Raw GMS Unix socket transport."""
For writers (lock_type=RequestedLockType.RW):
- Use context manager (with statement) for automatic lock release
- Call commit() after weights are written
- Call clear_all() before loading new model
For readers (lock_type=RequestedLockType.RO):
- Check committed property after construction
- Keep connection open during inference lifetime
- Only call close() when shutting down or allowing weight updates
"""
def __init__(
self,
socket_path: str,
lock_type: RequestedLockType = RequestedLockType.RO,
timeout_ms: Optional[int] = None,
):
"""Connect to Allocation Server and acquire lock.
Args:
socket_path: Path to server's Unix domain socket
lock_type: Requested lock type (RW, RO, or RW_OR_RO)
timeout_ms: Timeout in milliseconds for lock acquisition.
None means wait indefinitely.
Raises:
ConnectionError: If connection fails
TimeoutError: If timeout_ms expires waiting for lock
"""
def __init__(self, socket_path: str):
self.socket_path = socket_path
self._requested_lock_type = lock_type
self._socket: Optional[socket.socket] = None
self._recv_buffer = bytearray()
self._committed = False
self._granted_lock_type: Optional[GrantedLockType] = None
# Connect and acquire lock
self._connect(timeout_ms=timeout_ms)
@property
def is_connected(self) -> bool:
return self._socket is not None
def _connect(self, timeout_ms: Optional[int]) -> None:
"""Connect to server and perform handshake (lock acquisition)."""
def connect(self) -> None:
self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
self._socket.connect(self.socket_path)
except FileNotFoundError:
self._socket.close()
self._socket = None
raise ConnectionError(f"Server not running at {self.socket_path}") from None
except Exception as e:
self._socket.close()
self._socket = None
raise ConnectionError(f"Failed to connect: {e}") from e
# Handshake I/O — clean up socket on any failure
try:
request = HandshakeRequest(
lock_type=self._requested_lock_type,
timeout_ms=timeout_ms,
)
send_message_sync(self._socket, request)
# May block waiting for lock
response, _, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
)
except Exception:
self._socket.close()
self._socket = None
raise
if isinstance(response, ErrorResponse):
self._socket.close()
self._socket = None
raise ConnectionError(f"Handshake error: {response.error}")
if not isinstance(response, HandshakeResponse):
self._socket.close()
self._socket = None
raise ConnectionError(f"Unexpected response: {type(response)}")
if not response.success:
raise ConnectionError(
f"GMS server not running at {self.socket_path}"
) from None
except Exception as exc:
self._socket.close()
self._socket = None
raise TimeoutError("Timeout waiting for lock")
self._committed = response.committed
# Store granted lock type (may differ from requested for rw_or_ro mode)
if response.granted_lock_type is not None:
self._granted_lock_type = response.granted_lock_type
elif self._requested_lock_type == RequestedLockType.RW:
self._granted_lock_type = GrantedLockType.RW
else:
self._granted_lock_type = GrantedLockType.RO
logger.info(
f"Connected with {self._requested_lock_type.value} lock (granted={self._granted_lock_type.value}), "
f"committed={self._committed}"
)
raise ConnectionError(f"Failed to connect to GMS: {exc}") from exc
@property
def committed(self) -> bool:
"""Check if weights are committed (valid)."""
return self._committed
@property
def lock_type(self) -> Optional[GrantedLockType]:
"""Get the lock type actually granted by the server.
For rw_or_ro mode, this tells you whether RW or RO was granted.
"""
return self._granted_lock_type
@property
def is_connected(self) -> bool:
"""Check if client is connected."""
return self._socket is not None
def _send_recv(self, request) -> Tuple[object, int]:
"""Send request and receive response. Returns (response, fd)."""
if not self._socket:
raise RuntimeError("Client not connected")
send_message_sync(self._socket, request)
response, fd, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
def handshake(
self,
lock_type: RequestedLockType,
timeout_ms: Optional[int],
) -> HandshakeResponse:
response, _ = self.request_with_fd(
HandshakeRequest(lock_type=lock_type, timeout_ms=timeout_ms),
HandshakeResponse,
error_prefix="GMS handshake",
)
if isinstance(response, ErrorResponse):
raise RuntimeError(f"Server error: {response.error}")
return response, fd
def _call(self, request, response_type: Type[T]) -> T:
"""Send request, validate response type, return typed response."""
if type(request) in RW_REQUIRED and self.lock_type != GrantedLockType.RW:
raise RuntimeError("Operation requires RW connection")
response, _ = self._send_recv(request)
if not isinstance(response, response_type):
raise RuntimeError(f"Unexpected response: {type(response)}")
return response
def get_lock_state(self) -> GetLockStateResponse:
return self._call(GetLockStateRequest(), GetLockStateResponse)
def get_allocation_state(self) -> GetAllocationStateResponse:
return self._call(GetAllocationStateRequest(), GetAllocationStateResponse)
def request(self, request, response_type: Type[T]) -> T:
response, fd = self.request_with_fd(request, response_type)
if fd >= 0:
os.close(fd)
raise RuntimeError(
f"GMS request {type(request).__name__} returned an unexpected FD"
)
return response
def is_ready(self) -> bool:
return self.committed
def request_with_fd(
self,
request,
response_type: Type[T],
*,
error_prefix: Optional[str] = None,
) -> Tuple[T, int]:
response, fd = self._send_recv(request, error_prefix=error_prefix)
if not isinstance(response, response_type):
prefix = error_prefix or f"GMS request {type(request).__name__}"
if fd >= 0:
os.close(fd)
raise RuntimeError(
f"{prefix} returned unexpected response type: {type(response)}"
)
return response, fd
def commit(self) -> bool:
"""Commit weights and release RW lock. Returns True on success."""
if CommitRequest in RW_REQUIRED and self.lock_type != GrantedLockType.RW:
raise RuntimeError("Operation requires RW connection")
def _send_recv(
self, request, *, error_prefix: Optional[str] = None
) -> Tuple[object, int]:
if self._socket is None:
raise RuntimeError("Attempted GMS request on disconnected transport")
prefix = error_prefix or f"GMS request {type(request).__name__}"
try:
response, _ = self._send_recv(CommitRequest())
ok = isinstance(response, CommitResponse) and response.success
except (ConnectionResetError, BrokenPipeError, OSError) as e:
# Server closes RW socket as part of commit
logger.debug(
f"Commit saw socket error ({type(e).__name__}); verifying via RO connect"
send_message_sync(self._socket, request)
response, fd, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
)
self.close()
try:
ro = GMSRPCClient(
self.socket_path, lock_type=RequestedLockType.RO, timeout_ms=1000
)
try:
ok = ro.committed
finally:
ro.close()
except TimeoutError:
ok = False
if ok:
self._committed = True
self.close()
logger.info("Committed weights and released RW connection")
return True
return False
def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]:
"""Returns (allocation_id, aligned_size)."""
r = self._call(AllocateRequest(size=size, tag=tag), AllocateResponse)
return r.allocation_id, r.aligned_size
def export(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD. Caller must close."""
_, fd = self._send_recv(ExportRequest(allocation_id=allocation_id))
if fd < 0:
raise RuntimeError("No FD received from server")
return fd
def get_allocation(self, allocation_id: str) -> GetAllocationResponse:
return self._call(
GetAllocationRequest(allocation_id=allocation_id), GetAllocationResponse
)
def list_allocations(self, tag: Optional[str] = None) -> List[Dict]:
return self._call(
ListAllocationsRequest(tag=tag), ListAllocationsResponse
).allocations
def free(self, allocation_id: str) -> bool:
return self._call(
FreeRequest(allocation_id=allocation_id), FreeResponse
).success
def clear_all(self) -> int:
return self._call(ClearAllRequest(), ClearAllResponse).cleared_count
def metadata_put(
self, key: str, allocation_id: str, offset_bytes: int, value: bytes
) -> bool:
req = MetadataPutRequest(
key=key, allocation_id=allocation_id, offset_bytes=offset_bytes, value=value
)
return self._call(req, MetadataPutResponse).success
def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]:
"""Returns (allocation_id, offset_bytes, value) or None if not found."""
r = self._call(MetadataGetRequest(key=key), MetadataGetResponse)
return (r.allocation_id, r.offset_bytes, r.value) if r.found else None
def metadata_delete(self, key: str) -> bool:
return self._call(
MetadataDeleteRequest(key=key), MetadataDeleteResponse
).deleted
def metadata_list(self, prefix: str = "") -> List[str]:
return self._call(MetadataListRequest(prefix=prefix), MetadataListResponse).keys
def get_memory_layout_hash(self) -> str:
"""Get state hash (hash of allocations + metadata). Empty if not committed."""
return self._call(
GetStateHashRequest(), GetStateHashResponse
).memory_layout_hash
def close(self) -> None:
"""Close connection and release lock."""
if self._socket:
except Exception as exc:
try:
self._socket.close()
except Exception:
pass
self._socket = None
lock_str = self.lock_type.value if self.lock_type else "unknown"
logger.info(f"Closed {lock_str} connection")
raise ConnectionError(f"{prefix} failed: {exc}") from exc
def __enter__(self) -> "GMSRPCClient":
"""Context manager entry."""
if isinstance(response, ErrorResponse):
if fd >= 0:
os.close(fd)
raise RuntimeError(f"{prefix} error: {response.error}")
return response, fd
def close(self) -> None:
if self._socket is None:
return
try:
self._socket.close()
except Exception as exc:
raise ConnectionError(
f"Failed to close GMS transport socket: {exc}"
) from exc
finally:
self._socket = None
def __enter__(self) -> "_GMSRPCTransport":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Context manager exit."""
self.close()
def __del__(self):
"""Destructor: warn if connection not closed."""
if self._socket:
logger.warning("GMSRPCClient not closed properly")
if self._socket is not None:
try:
self._socket.close()
except Exception:
pass
self._socket = None
logger.warning("_GMSRPCTransport not closed properly")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Internal GPU Memory Service client session."""
from __future__ import annotations
import logging
from typing import List, Optional, Tuple
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
CommitRequest,
CommitResponse,
ExportAllocationRequest,
ExportAllocationResponse,
FreeAllocationRequest,
FreeAllocationResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetLockStateRequest,
GetLockStateResponse,
GetStateHashRequest,
GetStateHashResponse,
HandshakeResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__)
class _GMSClientSession:
"""Connected GMS client session with granted lock state."""
def __init__(
self,
socket_path: str,
lock_type: RequestedLockType,
timeout_ms: Optional[int],
):
self._requested_lock_type = lock_type
self._transport = _GMSRPCTransport(socket_path)
self._transport.connect()
try:
response = self._transport.handshake(lock_type, timeout_ms)
except Exception:
try:
self._transport.close()
except Exception:
pass
raise
self._initialize_from_handshake(response)
def _initialize_from_handshake(self, response: HandshakeResponse) -> None:
if not response.success:
self._transport.close()
raise TimeoutError("Timeout waiting for lock")
self._committed = response.committed
if response.granted_lock_type is None:
self._transport.close()
raise RuntimeError("HandshakeResponse omitted granted_lock_type")
self._granted_lock_type = response.granted_lock_type
logger.info(
"Connected with %s lock (granted=%s), committed=%s",
self._requested_lock_type.value,
self._granted_lock_type.value,
self._committed,
)
@property
def committed(self) -> bool:
return self._committed
@property
def lock_type(self) -> GrantedLockType:
return self._granted_lock_type
@property
def is_connected(self) -> bool:
return self._transport.is_connected
def get_lock_state(self) -> GetLockStateResponse:
return self._transport.request(GetLockStateRequest(), GetLockStateResponse)
def get_allocation_state(self) -> GetAllocationStateResponse:
return self._transport.request(
GetAllocationStateRequest(), GetAllocationStateResponse
)
def is_ready(self) -> bool:
return self.committed
def commit(self) -> bool:
response = self._transport.request(CommitRequest(), CommitResponse)
if not response.success:
raise RuntimeError("GMS commit returned failure")
self._committed = True
try:
self.close()
except ConnectionError as exc:
logger.warning("Commit succeeded but closing transport failed: %s", exc)
logger.info("Committed weights and released RW connection")
return True
def allocate_info(self, size: int, tag: str = "default") -> AllocateResponse:
return self._transport.request(
AllocateRequest(size=size, tag=tag), AllocateResponse
)
def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]:
response = self.allocate_info(size=size, tag=tag)
return response.allocation_id, response.aligned_size
def export(self, allocation_id: str) -> int:
response, fd = self._transport.request_with_fd(
ExportAllocationRequest(allocation_id=allocation_id),
ExportAllocationResponse,
)
if fd < 0:
raise RuntimeError(
f"GMS export returned no FD for allocation_id={allocation_id}"
)
return fd
def get_allocation(self, allocation_id: str) -> GetAllocationResponse:
return self._transport.request(
GetAllocationRequest(allocation_id=allocation_id),
GetAllocationResponse,
)
def list_allocations(
self, tag: Optional[str] = None
) -> List[GetAllocationResponse]:
return self._transport.request(
ListAllocationsRequest(tag=tag),
ListAllocationsResponse,
).allocations
def free(self, allocation_id: str) -> bool:
return self._transport.request(
FreeAllocationRequest(allocation_id=allocation_id),
FreeAllocationResponse,
).success
def metadata_put(
self, key: str, allocation_id: str, offset_bytes: int, value: bytes
) -> bool:
return self._transport.request(
MetadataPutRequest(
key=key,
allocation_id=allocation_id,
offset_bytes=offset_bytes,
value=value,
),
MetadataPutResponse,
).success
def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]:
response = self._transport.request(
MetadataGetRequest(key=key), MetadataGetResponse
)
if not response.found:
return None
return response.allocation_id, response.offset_bytes, response.value
def metadata_delete(self, key: str) -> bool:
return self._transport.request(
MetadataDeleteRequest(key=key), MetadataDeleteResponse
).deleted
def metadata_list(self, prefix: str = "") -> List[str]:
return self._transport.request(
MetadataListRequest(prefix=prefix), MetadataListResponse
).keys
def get_memory_layout_hash(self) -> str:
return self._transport.request(
GetStateHashRequest(), GetStateHashResponse
).memory_layout_hash
def close(self) -> None:
self._transport.close()
logger.info("Closed %s connection", self._granted_lock_type.value)
def __enter__(self) -> "_GMSClientSession":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
......@@ -12,7 +12,9 @@ This module provides PyTorch-specific functionality:
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import (
materialize_module_from_gms,
......@@ -23,6 +25,8 @@ __all__ = [
# GMS client memory manager
"get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
# Tensor operations (public API)
"register_module_tensors",
"materialize_module_from_gms",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service allocator management (singleton).
Manages a single weights memory manager and PyTorch MemPool integration.
Only one GMS scope is needed: weights. KV cache is handled by CuMemAllocator.
"""
"""GPU Memory Service allocator registry for PyTorch integration."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Optional, Tuple
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Optional
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
if TYPE_CHECKING:
import torch
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool
logger = logging.getLogger(__name__)
# Singleton state
_manager: Optional["GMSClientMemoryManager"] = None
_mem_pool: Optional["MemPool"] = None
_tag: str = "weights"
_callbacks_initialized: bool = False
_pluggable_alloc: Optional[Any] = None
@dataclass
class _TagState:
manager: "GMSClientMemoryManager"
mem_pool: "MemPool | None"
socket_path: str
device: int
_tag_states: dict[str, _TagState] = {}
_active_tag: ContextVar[str | None] = ContextVar(
"gpu_memory_service_active_tag",
default=None,
)
_callbacks_initialized = False
_pluggable_alloc: Any | None = None
def _gms_malloc(size: int, device: int, stream: int) -> int:
"""Route malloc to the singleton weights manager."""
if _manager is None:
raise RuntimeError("No GMS manager initialized")
va = _manager.create_mapping(size=int(size), tag=_tag)
logger.debug("[GMS] malloc: va=0x%x size=%d", va, size)
tag = _active_tag.get()
if tag is None:
raise RuntimeError("No active GMS allocation tag")
state = _tag_states.get(tag)
if state is None:
raise RuntimeError(f"Unknown GMS allocation tag: {tag}")
va = state.manager.create_mapping(size=int(size), tag=tag)
logger.debug("[GMS] malloc(tag=%s): va=0x%x size=%d", tag, va, size)
return va
def _gms_free(ptr: int, size: int, device: int, stream: int) -> None:
"""Route free to the singleton weights manager."""
if _manager is None:
logger.warning("[GMS] free: no manager, ignoring va=0x%x", ptr)
va = int(ptr)
for tag, state in _tag_states.items():
if va not in state.manager.mappings:
continue
logger.debug("[GMS] free(tag=%s): va=0x%x size=%d", tag, va, size)
state.manager.destroy_mapping(va)
return
if int(ptr) in _manager.mappings:
logger.debug("[GMS] free: va=0x%x size=%d", ptr, size)
_manager.destroy_mapping(int(ptr))
else:
logger.warning("[GMS] free: manager does not own va=0x%x, ignoring", ptr)
logger.warning("[GMS] free: no manager owns va=0x%x, ignoring", va)
def _ensure_callbacks_initialized() -> "MemPool":
"""Initialize C-level callbacks exactly once, return a new MemPool."""
def _ensure_callbacks_initialized() -> None:
global _callbacks_initialized, _pluggable_alloc
from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem
from torch.cuda import CUDAPluggableAllocator
from torch.cuda.memory import MemPool
if not _callbacks_initialized:
_pluggable_alloc = CUDAPluggableAllocator(
cumem.__file__, "my_malloc", "my_free"
)
cumem.init_module(_gms_malloc, _gms_free)
_callbacks_initialized = True
if _callbacks_initialized:
return
_pluggable_alloc = CUDAPluggableAllocator(cumem.__file__, "my_malloc", "my_free")
cumem.init_module(_gms_malloc, _gms_free)
_callbacks_initialized = True
def _create_mem_pool() -> "MemPool":
from torch.cuda.memory import MemPool
assert _pluggable_alloc is not None
return MemPool(allocator=_pluggable_alloc.allocator())
......@@ -74,66 +91,98 @@ def get_or_create_gms_client_memory_manager(
*,
tag: str = "weights",
timeout_ms: Optional[int] = None,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Get existing memory manager, or create a new one.
Args:
socket_path: Unix socket path for the allocation server.
device: CUDA device index.
mode: RW for cold start, RO for import-only, RW_OR_RO for auto.
tag: Allocation tag for RW mode.
timeout_ms: Lock acquisition timeout (None = wait indefinitely).
Returns:
(gms_client_memory_manager, pool) - pool is None for RO mode.
"""
global _manager, _mem_pool, _tag
) -> "GMSClientMemoryManager":
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
if _manager is not None:
return _get_existing(mode)
state = _tag_states.get(tag)
if state is not None:
if state.socket_path != socket_path or state.device != device:
raise RuntimeError(
f"GMS allocator tag={tag} was initialized for "
f"{state.socket_path} on device {state.device}, not {socket_path} "
f"on device {device}"
)
manager = state.manager
if not manager.is_connected:
if manager.mappings or manager.is_unmapped or manager.granted_lock_type:
raise RuntimeError(
f"GMS allocator tag={tag} is disconnected but still owns "
"preserved state; recreate the process instead of reusing it"
)
manager._client = None
manager._granted_lock_type = None
_tag_states.pop(tag, None)
state = None
if state is not None:
current = state.manager.granted_lock_type
if mode == RequestedLockType.RW and current != GrantedLockType.RW:
raise RuntimeError(
f"Cannot get RW allocator for tag {tag}: existing is in {current} mode"
)
if mode == RequestedLockType.RO and current != GrantedLockType.RO:
raise RuntimeError(
f"Cannot get RO allocator for tag {tag}: existing is in {current} mode"
)
return state.manager
manager = GMSClientMemoryManager(socket_path, device=device)
manager.connect(mode, timeout_ms=timeout_ms)
mem_pool = None
if manager.granted_lock_type == GrantedLockType.RW:
pool = _ensure_callbacks_initialized()
# Only set globals after mempool succeeds (avoids partial singleton)
_manager = manager
_tag = tag
_mem_pool = pool
logger.info("[GMS] Created RW allocator (device=%d)", device)
return manager, pool
else:
_manager = manager
_tag = tag
logger.info("[GMS] Created RO allocator (device=%d)", device)
return manager, None
def _get_existing(
mode: RequestedLockType,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Return existing allocator if mode-compatible."""
assert _manager is not None
current = _manager.granted_lock_type
_ensure_callbacks_initialized()
mem_pool = _create_mem_pool()
_tag_states[tag] = _TagState(
manager=manager,
mem_pool=mem_pool,
socket_path=socket_path,
device=device,
)
logger.info(
"[GMS] Created %s allocator for tag=%s (device=%d)",
manager.granted_lock_type.value,
tag,
device,
)
return manager
def get_gms_client_memory_manager(
tag: str = "weights",
) -> "GMSClientMemoryManager | None":
state = _tag_states.get(tag)
if state is None:
return None
return state.manager
def get_gms_client_memory_managers() -> tuple["GMSClientMemoryManager", ...]:
return tuple(state.manager for state in _tag_states.values())
if mode == RequestedLockType.RW:
if current == GrantedLockType.RW:
return _manager, _mem_pool
raise RuntimeError(f"Cannot get RW allocator: existing is in {current} mode")
def evict_gms_client_memory_manager(manager: "GMSClientMemoryManager") -> None:
for tag, state in list(_tag_states.items()):
if state.manager is manager:
_tag_states.pop(tag, None)
return
if mode == RequestedLockType.RO:
if current == GrantedLockType.RO:
return _manager, None
raise RuntimeError(f"Cannot get RO allocator: existing is in {current} mode")
# RW_OR_RO: return whatever exists
effective_pool = _mem_pool if current == GrantedLockType.RW else None
return _manager, effective_pool
@contextmanager
def gms_use_mem_pool(tag: str, device: "torch.device | int") -> Iterator[None]:
import torch
state = _tag_states.get(tag)
if state is None:
raise RuntimeError(f"No GMS allocator initialized for tag={tag}")
if state.mem_pool is None:
raise RuntimeError(f"GMS allocator tag={tag} does not have a mempool")
def get_gms_client_memory_manager() -> Optional["GMSClientMemoryManager"]:
"""Get the active GMS client memory manager, or None."""
return _manager
token = _active_tag.set(tag)
try:
with torch.cuda.use_mem_pool(state.mem_pool, device=device):
yield
finally:
_active_tag.reset(token)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA driver helpers shared by the GMS client and server."""
from __future__ import annotations
import atexit
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.utils import fail
_primary_contexts: dict[int, object] = {}
_primary_context_release_registered = False
def cuda_check_result(result: cuda.CUresult, name: str) -> None:
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
else:
err_msg = str(result)
fail("fatal CUDA VMM error in %s: %s", name, err_msg)
def cuda_ensure_initialized() -> None:
(result,) = cuda.cuInit(0)
cuda_check_result(result, "cuInit")
def cumem_get_allocation_granularity(device: int) -> int:
"""Get VMM allocation granularity for a device.
Args:
device: CUDA device index
Returns:
Allocation granularity in bytes (typically 2 MiB)
"""
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, granularity = cuda.cuMemGetAllocationGranularity(
prop, cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
)
cuda_check_result(result, "cuMemGetAllocationGranularity")
return int(granularity)
def cumem_create_tolerate_oom(size: int, device: int) -> tuple[bool, int]:
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, handle = cuda.cuMemCreate(size, prop, 0)
if result == cuda.CUresult.CUDA_SUCCESS:
return True, int(handle)
if result == cuda.CUresult.CUDA_ERROR_OUT_OF_MEMORY:
return False, 0
cuda_check_result(result, "cuMemCreate")
return False, 0
def cumem_export_to_shareable_handle(handle: int) -> int:
result, fd = cuda.cuMemExportToShareableHandle(
handle,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
0,
)
cuda_check_result(result, "cuMemExportToShareableHandle")
return int(fd)
def align_to_granularity(size: int, granularity: int) -> int:
"""Align size up to VMM granularity.
Args:
size: Size in bytes
granularity: Allocation granularity
Returns:
Aligned size
"""
return ((size + granularity - 1) // granularity) * granularity
def cumem_import_from_shareable_handle_close_fd(fd: int) -> int:
try:
result, handle = cuda.cuMemImportFromShareableHandle(
fd,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
)
cuda_check_result(result, "cuMemImportFromShareableHandle")
return int(handle)
finally:
os.close(fd)
def cumem_address_reserve(size: int, granularity: int) -> int:
result, va = cuda.cuMemAddressReserve(size, granularity, 0, 0)
cuda_check_result(result, "cuMemAddressReserve")
return int(va)
def cumem_address_free(va: int, size: int) -> None:
(result,) = cuda.cuMemAddressFree(va, size)
cuda_check_result(result, "cuMemAddressFree")
def cumem_map(va: int, size: int, handle: int) -> None:
(result,) = cuda.cuMemMap(va, size, 0, handle, 0)
cuda_check_result(result, "cuMemMap")
def cumem_set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
access_desc = cuda.CUmemAccessDesc()
access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
access_desc.location.id = device
access_desc.flags = (
cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
if access == GrantedLockType.RO
else cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
)
(result,) = cuda.cuMemSetAccess(va, size, [access_desc], 1)
cuda_check_result(result, "cuMemSetAccess")
def cumem_unmap(va: int, size: int) -> None:
(result,) = cuda.cuMemUnmap(va, size)
cuda_check_result(result, "cuMemUnmap")
def cumem_release(handle: int) -> None:
(result,) = cuda.cuMemRelease(handle)
cuda_check_result(result, "cuMemRelease")
def cuda_validate_pointer(va: int) -> None:
result, _ = cuda.cuPointerGetAttribute(
cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va
)
cuda_check_result(result, "cuPointerGetAttribute")
def cuda_synchronize() -> None:
(result,) = cuda.cuCtxSynchronize()
cuda_check_result(result, "cuCtxSynchronize")
def cuda_set_current_device(device: int) -> None:
global _primary_context_release_registered
ctx = _primary_contexts.get(device)
if ctx is None:
result, ctx = cuda.cuDevicePrimaryCtxRetain(device)
cuda_check_result(result, "cuDevicePrimaryCtxRetain")
_primary_contexts[device] = ctx
if not _primary_context_release_registered:
_primary_context_release_registered = True
atexit.register(_release_primary_contexts)
(result,) = cuda.cuCtxSetCurrent(ctx)
cuda_check_result(result, "cuCtxSetCurrent")
def _release_primary_contexts() -> None:
for device in list(_primary_contexts):
try:
(result,) = cuda.cuDevicePrimaryCtxRelease(device)
except Exception:
continue
if result == cuda.CUresult.CUDA_SUCCESS:
_primary_contexts.pop(device, None)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA Virtual Memory Management (VMM) utility functions.
This module provides utility functions for CUDA driver API operations
used by both server (GMSServerMemoryManager) and client (GMSClientMemoryManager).
"""
from cuda.bindings import driver as cuda
def check_cuda_result(result: cuda.CUresult, name: str) -> None:
"""Check CUDA driver API result and raise on error.
Args:
result: CUDA driver API return code (CUresult enum)
name: Operation name for error message
Raises:
RuntimeError: If result is not CUDA_SUCCESS
"""
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
else:
err_msg = str(result)
raise RuntimeError(f"{name}: {err_msg}")
def ensure_cuda_initialized() -> None:
"""Ensure CUDA driver is initialized.
Raises:
RuntimeError: If cuInit fails
"""
(result,) = cuda.cuInit(0)
check_cuda_result(result, "cuInit")
def get_allocation_granularity(device: int) -> int:
"""Get VMM allocation granularity for a device.
Args:
device: CUDA device index
Returns:
Allocation granularity in bytes (typically 2 MiB)
"""
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, granularity = cuda.cuMemGetAllocationGranularity(
prop, cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
)
check_cuda_result(result, "cuMemGetAllocationGranularity")
return int(granularity)
def align_to_granularity(size: int, granularity: int) -> int:
"""Align size up to VMM granularity.
Args:
size: Size in bytes
granularity: Allocation granularity
Returns:
Aligned size
"""
return ((size + granularity - 1) // granularity) * granularity
......@@ -4,7 +4,7 @@
"""Message types for GPU Memory Service RPC protocol."""
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import List, Optional, Union
import msgspec
......@@ -62,7 +62,6 @@ class GetAllocationStateRequest(msgspec.Struct, tag="get_allocation_state_reques
class GetAllocationStateResponse(msgspec.Struct, tag="get_allocation_state_response"):
allocation_count: int
total_bytes: int
class AllocateRequest(msgspec.Struct, tag="allocate_request"):
......@@ -74,12 +73,21 @@ class AllocateResponse(msgspec.Struct, tag="allocate_response"):
allocation_id: str
size: int
aligned_size: int
layout_slot: int
class ExportRequest(msgspec.Struct, tag="export_request"):
class ExportAllocationRequest(msgspec.Struct, tag="export_allocation_request"):
allocation_id: str
class ExportAllocationResponse(msgspec.Struct, tag="export_allocation_response"):
allocation_id: str
size: int
aligned_size: int
tag: str
layout_slot: int
class GetAllocationRequest(msgspec.Struct, tag="get_allocation_request"):
allocation_id: str
......@@ -89,6 +97,7 @@ class GetAllocationResponse(msgspec.Struct, tag="get_allocation_response"):
size: int
aligned_size: int
tag: str
layout_slot: int
class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
......@@ -96,25 +105,17 @@ class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
class ListAllocationsResponse(msgspec.Struct, tag="list_allocations_response"):
allocations: List[Dict[str, Any]] = []
allocations: List[GetAllocationResponse] = []
class FreeRequest(msgspec.Struct, tag="free_request"):
class FreeAllocationRequest(msgspec.Struct, tag="free_allocation_request"):
allocation_id: str
class FreeResponse(msgspec.Struct, tag="free_response"):
class FreeAllocationResponse(msgspec.Struct, tag="free_allocation_response"):
success: bool
class ClearAllRequest(msgspec.Struct, tag="clear_all_request"):
pass
class ClearAllResponse(msgspec.Struct, tag="clear_all_response"):
cleared_count: int
class ErrorResponse(msgspec.Struct, tag="error_response"):
error: str
code: int = 0
......@@ -166,6 +167,34 @@ class GetStateHashResponse(msgspec.Struct, tag="get_memory_layout_hash_response"
memory_layout_hash: str # Hash of allocations + metadata, empty if not committed
class GetRuntimeStateRequest(msgspec.Struct, tag="get_runtime_state_request"):
pass
class GetRuntimeStateResponse(msgspec.Struct, tag="get_runtime_state_response"):
state: str
has_rw_session: bool
ro_session_count: int
waiting_writers: int
committed: bool
is_ready: bool
allocation_count: int = 0
memory_layout_hash: str = ""
class GMSRuntimeEvent(msgspec.Struct):
kind: str
allocation_count: int = 0
class GetEventHistoryRequest(msgspec.Struct, tag="get_event_history_request"):
pass
class GetEventHistoryResponse(msgspec.Struct, tag="get_event_history_response"):
events: List[GMSRuntimeEvent] = []
Message = Union[
HandshakeRequest,
HandshakeResponse,
......@@ -177,15 +206,14 @@ Message = Union[
GetAllocationStateResponse,
AllocateRequest,
AllocateResponse,
ExportRequest,
ExportAllocationRequest,
ExportAllocationResponse,
GetAllocationRequest,
GetAllocationResponse,
ListAllocationsRequest,
ListAllocationsResponse,
FreeRequest,
FreeResponse,
ClearAllRequest,
ClearAllResponse,
FreeAllocationRequest,
FreeAllocationResponse,
ErrorResponse,
MetadataPutRequest,
MetadataPutResponse,
......@@ -197,6 +225,10 @@ Message = Union[
MetadataListResponse,
GetStateHashRequest,
GetStateHashResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
GetEventHistoryRequest,
GetEventHistoryResponse,
]
_encoder = msgspec.msgpack.Encoder()
......
......@@ -96,7 +96,11 @@ async def recv_message(
raw_msg, fds, _flags, _addr = await loop.run_in_executor(
None, lambda: socket.recv_fds(raw_sock, 65536, 1)
)
for extra_fd in fds[1:]:
os.close(extra_fd)
if not raw_msg:
if fds:
os.close(fds[0])
raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1
......@@ -107,21 +111,25 @@ async def recv_message(
recv_buffer.extend(chunk)
# Try to extract message, read more if needed
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
if raw_sock is not None:
# Continue reading from raw socket to avoid buffer inconsistency
chunk = await loop.run_in_executor(
None, lambda n=bytes_needed: raw_sock.recv(n)
)
else:
chunk = await reader.read(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
try:
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
if raw_sock is not None:
# Continue reading from raw socket to avoid buffer inconsistency
chunk = await loop.run_in_executor(
None, lambda n=bytes_needed: raw_sock.recv(n)
)
else:
chunk = await reader.read(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
except Exception:
if fd >= 0:
os.close(fd)
raise
# ==================== Sync (for client) ====================
......@@ -153,18 +161,26 @@ def recv_message_sync(
# Receive more data (with potential FD)
raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1)
for extra_fd in fds[1:]:
os.close(extra_fd)
if not raw_msg:
if fds:
os.close(fds[0])
raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1
# Try to extract message, read more if needed
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
chunk = sock.recv(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
try:
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
chunk = sock.recv(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
except Exception:
if fd >= 0:
os.close(fd)
raise
......@@ -8,10 +8,9 @@ from enum import Enum, auto
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
ClearAllRequest,
CommitRequest,
ExportRequest,
FreeRequest,
ExportAllocationRequest,
FreeAllocationRequest,
GetAllocationRequest,
GetAllocationStateRequest,
GetLockStateRequest,
......@@ -89,8 +88,7 @@ def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState:
RW_REQUIRED: frozenset[type] = frozenset(
{
AllocateRequest,
FreeRequest,
ClearAllRequest,
FreeAllocationRequest,
MetadataPutRequest,
MetadataDeleteRequest,
CommitRequest,
......@@ -99,7 +97,7 @@ RW_REQUIRED: frozenset[type] = frozenset(
RO_ALLOWED: frozenset[type] = frozenset(
{
ExportRequest,
ExportAllocationRequest,
GetAllocationRequest,
ListAllocationsRequest,
MetadataGetRequest,
......
......@@ -3,36 +3,39 @@
"""Shared utilities for GPU Memory Service."""
import logging
import os
import tempfile
import uuid
from typing import NoReturn
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import (
check_cuda_result,
ensure_cuda_initialized,
)
logger = logging.getLogger(__name__)
def get_socket_path(device: int) -> str:
"""Get GMS socket path for the given CUDA device.
def fail(message: str, *args, exc_info=None) -> NoReturn:
logger.critical(message, *args, exc_info=exc_info)
logging.shutdown()
os._exit(1)
The socket path is based on GPU UUID resolved by CUDA.
CUDA_VISIBLE_DEVICES remapping is handled by CUDA device enumeration.
def get_socket_path(device: int, tag: str = "weights") -> str:
"""Get GMS socket path for the given CUDA device and tag.
The socket path is based on GPU UUID, making it stable across different
CUDA_VISIBLE_DEVICES configurations.
Args:
device: CUDA device index.
Returns:
Socket path (e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc.sock").
Socket path
(e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock").
"""
ensure_cuda_initialized()
result, cu_device = cuda.cuDeviceGet(device)
check_cuda_result(result, "cuDeviceGet")
result, cu_uuid = cuda.cuDeviceGetUuid(cu_device)
check_cuda_result(result, "cuDeviceGetUuid")
gpu_uuid = f"GPU-{uuid.UUID(bytes=bytes(cu_uuid.bytes))}"
return os.path.join(tempfile.gettempdir(), f"gms_{gpu_uuid}.sock")
import pynvml
pynvml.nvmlInit()
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
uuid = pynvml.nvmlDeviceGetUUID(handle)
finally:
pynvml.nvmlShutdown()
return os.path.join(tempfile.gettempdir(), f"gms_{uuid}_{tag}.sock")
......@@ -8,7 +8,7 @@ from __future__ import annotations
import logging
import torch
from gpu_memory_service import get_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import get_gms_client_memory_managers
logger = logging.getLogger(__name__)
......@@ -32,11 +32,13 @@ def patch_empty_cache() -> None:
_original_empty_cache = torch.cuda.empty_cache
def safe_empty_cache() -> None:
manager = get_gms_client_memory_manager()
if manager is not None and len(manager.mappings) > 0:
mapping_count = sum(
len(manager.mappings) for manager in get_gms_client_memory_managers()
)
if mapping_count > 0:
logger.debug(
"[GMS] Skipping torch.cuda.empty_cache() - %d VMM allocations active",
len(manager.mappings),
mapping_count,
)
return
_original_empty_cache()
......
......@@ -6,6 +6,7 @@
from __future__ import annotations
import logging
from dataclasses import replace
from typing import TYPE_CHECKING
import torch
......@@ -29,6 +30,20 @@ def get_gms_lock_mode(extra_config: dict):
return RequestedLockType.RW_OR_RO
def strip_gms_model_loader_config(load_config, load_format: str):
"""Copy a loader config with GMS-only keys removed for backend loaders."""
extra_config = getattr(load_config, "model_loader_extra_config", {}) or {}
return replace(
load_config,
load_format=load_format,
model_loader_extra_config={
key: value
for key, value in extra_config.items()
if not key.startswith("gms_")
},
)
def setup_meta_tensor_workaround() -> None:
"""Enable workaround for meta tensor operations like torch.nonzero()."""
try:
......@@ -42,9 +57,9 @@ def setup_meta_tensor_workaround() -> None:
def finalize_gms_write(
allocator: "GMSClientMemoryManager", model: torch.nn.Module
) -> int:
"""Finalize GMS write mode: register tensors, commit, switch to read.
"""Finalize GMS write mode: register tensors, commit, reconnect in read mode.
Flow: register tensors -> sync -> commit (server-only) -> disconnect -> connect(RO)
Flow: register tensors -> sync -> unmap + commit -> connect(RO) -> remap
Args:
allocator: The GMS client memory manager in write mode.
......@@ -52,9 +67,6 @@ def finalize_gms_write(
Returns:
Total bytes committed.
Raises:
RuntimeError: If commit fails.
"""
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.types import RequestedLockType
......@@ -65,12 +77,10 @@ def finalize_gms_write(
# Synchronize before commit — caller's writes must be visible
torch.cuda.synchronize()
if not allocator.commit():
raise RuntimeError("GMS commit failed")
allocator.commit()
# commit() closed the RW socket; acquire RO for inference
allocator.disconnect() # no-op if commit already cleared _client, but safe
allocator.connect(RequestedLockType.RO)
allocator.remap_all_vas()
logger.info(
"[GMS] Committed %.2f GiB, switched to read mode with %d mappings",
......
......@@ -3,13 +3,10 @@
"""Hybrid torch_memory_saver implementation for GPU Memory Service.
This module provides a hybrid implementation that combines:
1. GPU Memory Service allocator for "weights" tag (VA-stable unmap/remap, shared)
2. Torch mempool mode for other tags like "kv_cache" (CPU backup, per-instance)
The impl uses RW_OR_RO mode to connect to GMS:
- First process gets RW lock and loads weights from disk
- Subsequent processes get RO lock and import weights from metadata
This module uses:
1. GPU Memory Service for "weights" (shared RO/RW publish flow)
2. GPU Memory Service for "kv_cache" (RW-only failover flow)
3. torch_memory_saver for any remaining tags
"""
from __future__ import annotations
......@@ -19,10 +16,13 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
logger = logging.getLogger(__name__)
......@@ -39,56 +39,54 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
class GMSMemorySaverImpl:
"""Hybrid implementation: GMS for weights, torch mempool for KV cache.
Routes operations based on tag:
- "weights" or "model_weights": Handled by GMS allocator (VA-stable)
- Other tags (e.g., "kv_cache"): Delegated to torch mempool mode
"""
"""Hybrid implementation: GMS for weights and KV cache."""
def __init__(
self,
torch_impl: "_TorchMemorySaverImpl",
socket_path: str,
device_index: int,
mode=None,
):
self._torch_impl = torch_impl
self._socket_path = socket_path
self._device_index = device_index
self._requested_mode = mode
self._disabled = False
self._imported_weights_bytes: int = 0
self._allocator: Optional["GMSClientMemoryManager"]
self._mem_pool: Optional["MemPool"]
self._weights_allocator: Optional["GMSClientMemoryManager"]
self._kv_cache_allocator: "GMSClientMemoryManager"
self._mode: str
self._allocator, self._mem_pool, self._mode = self._init_allocator()
(
self._weights_allocator,
self._kv_cache_allocator,
self._mode,
) = self._init_allocators()
logger.info(
"[GMS] Initialized: weights=%s mode (device=%d, socket=%s)",
"[GMS] Initialized weights=%s mode, kv_cache=RW (device=%d)",
self._mode.upper(),
device_index,
socket_path,
)
def _init_allocator(
def _init_allocators(
self,
) -> tuple[Optional["GMSClientMemoryManager"], Optional["MemPool"], str]:
) -> tuple[Optional["GMSClientMemoryManager"], "GMSClientMemoryManager", str,]:
"""Create allocator with mode from config (default: RW_OR_RO)."""
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
mode = self._requested_mode or RequestedLockType.RW_OR_RO
allocator, mem_pool = get_or_create_gms_client_memory_manager(
self._socket_path,
weights_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "weights"),
self._device_index,
mode=mode,
tag="weights",
)
granted_mode = allocator.granted_lock_type
kv_cache_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "kv_cache"),
self._device_index,
mode=RequestedLockType.RW,
tag="kv_cache",
)
granted_mode = weights_allocator.granted_lock_type
if granted_mode == GrantedLockType.RW:
allocator.clear_all_handles()
actual_mode = "write"
else:
actual_mode = "read"
......@@ -97,11 +95,7 @@ class GMSMemorySaverImpl:
actual_mode.upper(),
self._device_index,
)
return (
allocator,
mem_pool if granted_mode == GrantedLockType.RW else None,
actual_mode,
)
return weights_allocator, kv_cache_allocator, actual_mode
def _is_weights_tag(self, tag: Optional[str]) -> bool:
return tag in ("weights", "model_weights")
......@@ -110,25 +104,28 @@ class GMSMemorySaverImpl:
return self._mode
def get_allocator(self) -> Optional["GMSClientMemoryManager"]:
return self._allocator
return self._weights_allocator
@contextmanager
def region(self, tag: str, enable_cpu_backup: bool):
"""Mark allocation region with tag."""
if not self._is_weights_tag(tag):
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
if self._is_weights_tag(tag):
if self._mode == "read":
yield
return
return
if self._mode == "read":
yield
target_device = torch.device("cuda", self._device_index)
with gms_use_mem_pool("weights", target_device):
yield
return
if self._mem_pool is None:
raise RuntimeError("GMS mempool is None in WRITE mode")
if tag == "kv_cache":
target_device = torch.device("cuda", self._device_index)
with gms_use_mem_pool("kv_cache", target_device):
yield
return
target_device = torch.device("cuda", self._device_index)
with torch.cuda.use_mem_pool(self._mem_pool, device=target_device):
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
yield
def pause(self, tag: Optional[str] = None) -> None:
......@@ -136,7 +133,9 @@ class GMSMemorySaverImpl:
return
if tag is None or self._is_weights_tag(tag):
self._pause_weights()
if tag is None or not self._is_weights_tag(tag):
if tag is None or tag == "kv_cache":
self._pause_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.pause(tag=tag)
def resume(self, tag: Optional[str] = None) -> None:
......@@ -144,39 +143,56 @@ class GMSMemorySaverImpl:
return
if tag is None or self._is_weights_tag(tag):
self._resume_weights()
if tag is None or not self._is_weights_tag(tag):
if tag is None or tag == "kv_cache":
self._resume_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.resume(tag=tag)
def _pause_weights(self) -> None:
if self._allocator is None:
if self._weights_allocator is None:
return
if self._allocator.is_unmapped:
if self._weights_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap_all_vas()
self._allocator.disconnect()
self._weights_allocator.unmap_all_vas()
self._weights_allocator.abort()
def _resume_weights(self) -> None:
if self._allocator is None:
if self._weights_allocator is None:
return
if not self._allocator.is_unmapped:
if not self._weights_allocator.is_unmapped:
return
logger.info("[GMS] Remapping weights (VA-stable)")
from gpu_memory_service.common.types import RequestedLockType
self._weights_allocator.connect(RequestedLockType.RO)
self._weights_allocator.remap_all_vas()
self._allocator.connect(RequestedLockType.RO)
self._allocator.remap_all_vas()
def _pause_kv_cache(self) -> None:
if self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping KV cache")
self._kv_cache_allocator.unmap_all_vas()
self._kv_cache_allocator.abort()
def _resume_kv_cache(self) -> None:
if not self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Remapping KV cache")
self._kv_cache_allocator.connect(RequestedLockType.RW)
self._kv_cache_allocator.reallocate_all_handles(tag="kv_cache")
self._kv_cache_allocator.remap_all_vas()
def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read."""
if self._mode != "write":
return
if self._allocator is None:
if self._weights_allocator is None:
raise RuntimeError("Allocator is None in WRITE mode")
from gpu_memory_service.integrations.common.utils import finalize_gms_write
self._imported_weights_bytes = finalize_gms_write(self._allocator, model)
self._imported_weights_bytes = finalize_gms_write(
self._weights_allocator, model
)
self._mode = "read"
def set_imported_weights_bytes(self, bytes_count: int) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment