Unverified Commit 91a8c6af authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix: simplify GMS layout state and harden GPU-backed flows (#7006)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
Co-authored-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: default avatarhhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent dd7ceb4a
...@@ -28,20 +28,20 @@ This leads to: ...@@ -28,20 +28,20 @@ This leads to:
┌──────────────────────────────────────────────────────────────────────────────────────┐ ┌──────────────────────────────────────────────────────────────────────────────────────┐
│ │ │ │
│ ┌────────────────────┐ ┌─────────────────────────────────────────┐ │ │ ┌────────────────────┐ ┌─────────────────────────────────────────┐ │
│ │ GMS Server │ │ GMSClientMemoryManager (Writer) │ │ │ │ GMS │ │ GMSClientMemoryManager (Writer) │ │
│ │ │ │ │ │ │ │ │ │ │ │
│ │ ┌────────────────┐ │ │ ┌─────────────────────────────────┐ │ │ │ │ ┌────────────────┐ │ │ ┌─────────────────────────────────┐ │ │
│ │ │ Memory Manager │ │ ◄── Unix ───────►│ │ GMSRPCClient │ │ │ │ │ │ Memory Manager │ │ ◄── Unix ───────►│ │ GMS Session │ │ │
│ │ └────────────────┘ │ Socket │ └─────────────────────────────────┘ │ │ │ │ └────────────────┘ │ Socket │ └─────────────────────────────────┘ │ │
│ │ │ + │ │ │ │ │ │ + │ │ │
│ │ ┌────────────────┐ │ FD │ Writer-only: create_mapping, commit │ │ │ │ ┌────────────────┐ │ FD │ Writer-only: create_mapping, commit │ │
│ │ │ State Machine │ │ (SCM_RIGHTS) └─────────────────────────────────────────┘ │ │ │ │ Session / FSM │ │ (SCM_RIGHTS) └─────────────────────────────────────────┘ │
│ │ └────────────────┘ │ │ │ │ └────────────────┘ │ │
│ │ │ ┌─────────────────────────────────────────┐ │ │ │ │ ┌─────────────────────────────────────────┐ │
│ │ ┌────────────────┐ │ │ GMSClientMemoryManager (Reader) │ │ │ │ ┌────────────────┐ │ │ GMSClientMemoryManager (Reader) │ │
│ │ │ Metadata Store │ │ │ │ │ │ │ │ Metadata Store │ │ │ │ │
│ │ └────────────────┘ │ ◄── Unix ───────►│ ┌─────────────────────────────────┐ │ │ │ │ └────────────────┘ │ ◄── Unix ───────►│ ┌─────────────────────────────────┐ │ │
│ │ │ Socket │ │ GMSRPCClient │ │ │ │ │ │ Socket │ │ GMS Session │ │ │
│ └────────────────────┘ + │ └─────────────────────────────────┘ │ │ │ └────────────────────┘ + │ └─────────────────────────────────┘ │ │
│ FD │ │ │ │ FD │ │ │
│ (SCM_RIGHTS) │ Reader-only: create_mapping (import), │ │ │ (SCM_RIGHTS) │ Reader-only: create_mapping (import), │ │
...@@ -65,25 +65,25 @@ The GMS server runs as an independent process that manages GPU memory without ev ...@@ -65,25 +65,25 @@ The GMS server runs as an independent process that manages GPU memory without ev
The server consists of three main components: The server consists of three main components:
1. **Memory Manager** - Allocates physical GPU memory via CUDA VMM (`cuMemCreate`) and exports shareable file descriptors (`cuMemExportToShareableHandle`). Critically, it never calls `cuMemMap` - clients handle all virtual address mapping. 1. **Memory Manager** - Allocates physical GPU memory via CUDA VMM (`cuMemCreate`) and exports shareable file descriptors (`cuMemExportToShareableHandle`). Critically, it never calls `cuMemMap` - clients handle all virtual address mapping. Allocation requests retry on OOM until they succeed or the optional retry timeout is reached.
2. **State Machine (FSM)** - Manages the global lock state and enforces access rules that ensures consistency across multiple clients. See [State Machine](#state-machine) below for details. 2. **State Machine (FSM)** - Manages global lock state, waiter coordination, and disconnect cleanup.
3. **Metadata Store** - Key-value store for tensor metadata (shapes, dtypes, offsets), enabling clients to reconstruct model structure. 3. **Metadata Store / Layout State** - `GMS` owns the metadata table and committed layout hash. Allocations and metadata live in one flat store that is cleared on each new writer connect or writer abort.
### Client Each GMS server is responsible for managing memory of only 1 GPU, and does not interact with GMS servers corresponding to other GPUs.
Clients connect to the server to acquire locks and access GPU memory. Two client classes are provided: ### Client
1. **GMSRPCClient** - Low-level RPC client for direct protocol access. Handles socket communication, msgpack serialization, and file descriptor passing via `SCM_RIGHTS`. The socket connection **is** the lock - connection lifetime equals lock lifetime, providing automatic crash resilience. Clients connect to the server to acquire locks and access GPU memory. The supported client API is:
2. **GMSClientMemoryManager** - High-level client that wraps `GMSRPCClient` and handles all CUDA VMM operations for memory import and mapping safely: 1. **GMSClientMemoryManager** - High-level client that wraps an internal RPC transport layer and handles all CUDA VMM operations for memory import and mapping safely:
- Imports file descriptors and converts them to CUDA memory handles - Imports file descriptors and converts them to CUDA memory handles
- Reserves virtual address space and maps physical memory - Reserves virtual address space and maps physical memory
- Sets appropriate access permissions (RW for writers, RO for readers) - Sets appropriate access permissions (RW for writers, RO for readers)
- Supports **unmap/remap** for VA-stable memory release under memory pressure - Supports **unmap/remap** for VA-stable memory release under memory pressure
> **Note**: Always use `GMSClientMemoryManager` to interact with GMS from client code. The low-level `GMSRPCClient` is an implementation detail and should not be used directly. > **Note**: Always use `GMSClientMemoryManager` to interact with GMS from client code. The low-level RPC client is an implementation detail and should not be used directly.
### Memory Allocation and Import Flow ### Memory Allocation and Import Flow
...@@ -92,7 +92,7 @@ The following diagram shows how `GMSClientMemoryManager` interacts with the serv ...@@ -92,7 +92,7 @@ The following diagram shows how `GMSClientMemoryManager` interacts with the serv
```mermaid ```mermaid
sequenceDiagram sequenceDiagram
participant C as GMSClientMemoryManager participant C as GMSClientMemoryManager
participant S as GMS Server participant S as GMS
participant GPU as GPU Memory participant GPU as GPU Memory
%% Connection %% Connection
...@@ -111,7 +111,7 @@ sequenceDiagram ...@@ -111,7 +111,7 @@ sequenceDiagram
%% Export/Import (Both Writer and Reader) %% Export/Import (Both Writer and Reader)
Note over C,GPU: Both Writer and Reader: Export and map Note over C,GPU: Both Writer and Reader: Export and map
C->>S: ExportRequest(allocation_id) C->>S: ExportAllocationRequest(allocation_id)
S->>GPU: cuMemExportToShareableHandle(handle) S->>GPU: cuMemExportToShareableHandle(handle)
GPU-->>S: fd GPU-->>S: fd
S-->>C: Response + fd (via SCM_RIGHTS) S-->>C: Response + fd (via SCM_RIGHTS)
...@@ -152,31 +152,103 @@ stateDiagram-v2 ...@@ -152,31 +152,103 @@ stateDiagram-v2
| State | Description | Can Connect RW | Can Connect RO | | State | Description | Can Connect RW | Can Connect RO |
|-------|-------------|:--------------:|:--------------:| |-------|-------------|:--------------:|:--------------:|
| `EMPTY` | No connections, no committed weights | ✓ | ✗ | | `EMPTY` | No connections, no committed layout visible | ✓ | ✗ |
| `RW` | Writer connected (exclusive access) | ✗ | ✗ | | `RW` | Writer connected (exclusive access) | ✗ | ✗ |
| `COMMITTED` | Weights published, no active connections | ✓ | ✓ | | `COMMITTED` | Committed layout visible to readers, no active connections | ✓ | ✓ |
| `RO` | One or more readers connected (shared access) | ✗ | ✓ | | `RO` | One or more readers connected (shared access) | ✗ | ✓ |
### Events ### Events
| Event | Trigger | Description | | Event | Trigger | Description |
|-------|---------|-------------| |-------|---------|-------------|
| `RW_CONNECT` | Writer connects | Acquires exclusive write lock | | `RW_CONNECT` | Writer connects | Acquires exclusive write lock, clears the previous committed layout immediately, and starts a fresh RW layout build |
| `RW_COMMIT` | Writer calls `commit()` | Publishes weights, releases lock | | `RW_COMMIT` | Writer calls `commit()` | Publishes the current RW layout as the committed layout and releases the lock |
| `RW_ABORT` | Writer disconnects without commit | Discards allocations, releases lock | | `RW_ABORT` | Writer disconnects without commit | Drops the active RW layout and returns to `EMPTY` |
| `RO_CONNECT` | Reader connects | Acquires shared read lock | | `RO_CONNECT` | Reader connects | Acquires shared read lock |
| `RO_DISCONNECT` | Reader disconnects | Releases shared lock; if last reader, returns to COMMITTED | | `RO_DISCONNECT` | Reader disconnects | Releases shared lock; if last reader, returns to COMMITTED |
### Lock Semantics ### Lock Semantics
The socket connection **is** the lock: A handshaken socket connection **is** the lock:
- **Crash resilience**: Connection close (including process crash) automatically releases the lock - **Crash resilience**: Connection close (including process crash) automatically releases the lock
- **No explicit unlock**: Eliminates forgotten locks and deadlocks - **No explicit unlock**: Eliminates forgotten locks and deadlocks
- **Atomic transitions**: State changes happen atomically with socket operations - **Atomic transitions**: State changes happen atomically with socket operations
The only exception is the runtime inspection probes (`GetRuntimeState`, `GetEventHistory`): they connect, fetch diagnostics, and close without entering the lock FSM.
### Layout Lifecycle
Layout creation and publication work like this:
```mermaid
flowchart LR
A[EMPTY or COMMITTED] -->|RW_CONNECT| B[Fresh RW layout]
B -->|Allocate memory and write metadata| C{Writer outcome}
C -->|RW_COMMIT| D[Publish layout as committed]
C -->|RW_ABORT| E[Discard layout]
D -->|Next RW_CONNECT| F[Fresh RW layout]
E -->|Next RW_CONNECT| F
```
- `RW_CONNECT` starts a fresh RW layout build.
- `RW_COMMIT` publishes the current layout; it does not create another one.
- `RW_ABORT` discards the current RW layout and returns the system to `EMPTY`.
- Allocations and metadata live in one flat store that is cleared on `RW_CONNECT` and `RW_ABORT`.
- RO requests are served only from the committed layout, while RW requests mutate only the active layout.
- Read RPCs (`export`, allocation lookup/listing, metadata lookup/listing) operate on that single live store. This is safe because the FSM prevents RW and RO sessions from coexisting.
- `metadata_put` validates allocation ownership and offset bounds, `free` cascades metadata cleanup, and `commit` rejects dangling metadata references.
### Allocation Backpressure on OOM
When a writer requests a new allocation, GMS treats CUDA OOM as a transient condition:
- `cuMemCreate` OOM does **not** immediately fail the request.
- The server retries in a loop and only returns success after allocation is created.
- Server CLI flags:
- `--alloc-retry-interval` (default `0.5`)
- `--alloc-retry-timeout` (default unset = wait indefinitely)
This ensures the "new writer gets fresh allocations" workflow can wait for memory reclamation instead of racing into immediate OOM failures.
### Guarantees
- GMS guarantees that its own RPCs do not mix committed and active generations, and that `GMSClientMemoryManager.commit()` performs a CUDA synchronize and unmaps the writer's local mappings before publish.
- After local unmap, `commit()` does not attempt in-process recovery. Non-CUDA failures raise, and CUDA VMM failures exit the process.
- The only non-fatal client connection failure is lock acquisition timeout. Other client-side GMS transport, protocol, and server error responses raise.
- Any non-OOM CUDA VMM failure on either client or server is fatal and exits the process.
- On the server, an untrusted client connection is isolated to that connection: transport loss and response-send failures unwind the connection state, and only server invariant violations or CUDA failures kill the server.
- Runtime-state `allocation_count` and `allocations_cleared` report server-owned allocation handles only. Imported handles in other processes can still keep VRAM alive after the server clears its own layout state.
- GMS *does not* prove that a disconnected or already-submitted writer has no in-flight GPU work left on the device. The mitigation in this design is that new RW layouts use fresh allocations and may wait for memory reclamation before allocation succeeds.
--- ---
### Server Trust Boundary
```mermaid
flowchart TD
A[Client event on server connection] --> B{Can server read and decode it?}
B -- no --> C[Drop connection]
C --> D[Run disconnect cleanup]
D --> E[RW_ABORT or RO_DISCONNECT]
B -- yes --> F{Valid client request?}
F -- no --> G[Send ErrorResponse]
F -- yes --> H{Did request expose server invariant failure?}
H -- yes --> I[Exit server process]
H -- no --> J[Build response or apply commit]
J --> K{Can server send response?}
K -- no --> D
K -- yes --> L[Continue session or close committed writer]
```
- `Drop connection` means the server stops trusting that socket and unwinds only that connection's lock state.
- After `RW_COMMIT`, disconnect cleanup only closes the committed writer socket; it does not roll the server back to `RW_ABORT`.
- `Valid client request?` covers mode/state violations, unknown requests, and request validation failures like bad metadata offsets.
- `Did request expose server invariant failure?` covers impossible layout/FSM states and commit-time metadata integrity failures.
## Sequence Diagrams ## Sequence Diagrams
### Writer Flow (Cold Start) ### Writer Flow (Cold Start)
...@@ -187,11 +259,14 @@ The first worker loads weights from disk and publishes them to GMS. ...@@ -187,11 +259,14 @@ The first worker loads weights from disk and publishes them to GMS.
sequenceDiagram sequenceDiagram
participant W as Writer Process participant W as Writer Process
participant C as GMSClientMemoryManager participant C as GMSClientMemoryManager
participant S as GMS Server participant S as GMS
W->>C: mgr = GMSClientMemoryManager(socket_path, device=0) W->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
W->>C: mgr.connect(RW) W->>C: mgr.connect(RW)
C->>S: HandshakeRequest(lock_type=RW) C->>S: HandshakeRequest(lock_type=RW)
S->>S: Session FSM: EMPTY/COMMITTED -> RW
S->>S: Clear prior committed layout
S->>S: Start fresh RW layout
S-->>C: HandshakeResponse(success=true) S-->>C: HandshakeResponse(success=true)
loop For each tensor loop For each tensor
...@@ -201,9 +276,14 @@ sequenceDiagram ...@@ -201,9 +276,14 @@ sequenceDiagram
end end
W->>C: mgr.commit() W->>C: mgr.commit()
C->>GPU: synchronize()
C->>GPU: cuMemUnmap(...) + cuMemRelease(...)
C->>S: CommitRequest() C->>S: CommitRequest()
S->>S: Publish current layout as committed
S->>S: FSM: RW → COMMITTED S->>S: FSM: RW → COMMITTED
S-->>C: CommitResponse(success=true) S-->>C: CommitResponse(success=true)
W->>C: mgr.connect(RO)
W->>C: mgr.remap_all_vas()
``` ```
### Reader Flow (Warm Start) ### Reader Flow (Warm Start)
...@@ -214,7 +294,7 @@ Subsequent workers import weights from GMS instead of loading from disk. ...@@ -214,7 +294,7 @@ Subsequent workers import weights from GMS instead of loading from disk.
sequenceDiagram sequenceDiagram
participant R as Reader Process participant R as Reader Process
participant C as GMSClientMemoryManager participant C as GMSClientMemoryManager
participant S as GMS Server participant S as GMS
R->>C: mgr = GMSClientMemoryManager(socket_path, device=0) R->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
R->>C: mgr.connect(RO) R->>C: mgr.connect(RO)
...@@ -242,7 +322,7 @@ Readers can temporarily release GPU memory while preserving virtual address rese ...@@ -242,7 +322,7 @@ Readers can temporarily release GPU memory while preserving virtual address rese
sequenceDiagram sequenceDiagram
participant R as Reader Process participant R as Reader Process
participant C as GMSClientMemoryManager participant C as GMSClientMemoryManager
participant S as GMS Server participant S as GMS
participant GPU as GPU Memory participant GPU as GPU Memory
Note over R,GPU: Need to temporarily release GPU memory Note over R,GPU: Need to temporarily release GPU memory
...@@ -256,33 +336,26 @@ sequenceDiagram ...@@ -256,33 +336,26 @@ sequenceDiagram
Note over C: Keep VA reservation! Note over C: Keep VA reservation!
end end
R->>C: mgr.disconnect() R->>C: mgr.abort()
C->>S: Close socket (release RO lock) C->>S: Close socket (release RO lock)
S->>S: FSM: RO → COMMITTED (if last reader) S->>S: FSM: RO → COMMITTED (if last reader)
Note over R,GPU: GPU memory released, VA preserved Note over R,GPU: GPU memory released, VA preserved
Note over R,GPU: Another writer could modify weights here Note over R,GPU: Another writer could publish a new layout here
R->>C: mgr.connect(RO) R->>C: mgr.connect(RO)
C->>S: HandshakeRequest(lock_type=RO)
S->>S: FSM: COMMITTED → RO
S-->>C: HandshakeResponse(success=true)
R->>C: mgr.remap_all_vas() R->>C: mgr.remap_all_vas()
C->>S: GetStateHashRequest() C->>S: GetStateHashRequest()
S-->>C: GetStateHashResponse(hash) S-->>C: GetStateHashResponse(hash)
alt hash == saved_hash alt hash == saved_hash
loop For each preserved VA C->>S: Export preserved allocations from the committed layout
C->>S: ExportRequest(allocation_id) S-->>C: Response + FDs
S-->>C: Response + fd C->>GPU: Import handles and remap at preserved VAs
C->>GPU: cuMemImportFromShareableHandle(fd) C-->>R: Remap succeeds and tensor pointers stay valid
C->>GPU: cuMemMap(same_va, handle)
Note over C: Tensors valid at same addresses!
end
else hash != saved_hash else hash != saved_hash
C-->>R: StaleMemoryLayoutError C-->>R: StaleMemoryLayoutError
Note over R: Must re-import from scratch C-->>R: Re-import from scratch
end end
``` ```
...@@ -294,9 +367,9 @@ The `RW_OR_RO` mode automatically selects writer or reader based on server state ...@@ -294,9 +367,9 @@ The `RW_OR_RO` mode automatically selects writer or reader based on server state
sequenceDiagram sequenceDiagram
participant P as Process participant P as Process
participant C as GMSClientMemoryManager participant C as GMSClientMemoryManager
participant S as GMS Server participant S as GMS
Note over P,S: Auto-mode: Writer if first, Reader if weights exist Note over P,S: Auto-mode: try RW only when no committed layout exists
P->>C: mgr = GMSClientMemoryManager(socket_path, device=0) P->>C: mgr = GMSClientMemoryManager(socket_path, device=0)
P->>C: mgr.connect(RW_OR_RO) P->>C: mgr.connect(RW_OR_RO)
...@@ -313,10 +386,10 @@ sequenceDiagram ...@@ -313,10 +386,10 @@ sequenceDiagram
S-->>C: HandshakeResponse(granted=RO, committed=true) S-->>C: HandshakeResponse(granted=RO, committed=true)
Note over P: Subsequent process - import from GMS Note over P: Subsequent process - import from GMS
else RW held by another else RW held by another
S->>S: Wait for RO availability S->>S: Wait until a committed layout becomes available
S->>S: FSM: COMMITTED → RO S->>S: Then grant RO from COMMITTED
S-->>C: HandshakeResponse(granted=RO, committed=true) S-->>C: HandshakeResponse(granted=RO, committed=true)
Note over P: Wait for writer to finish Note over P: Wait for writer to publish committed weights
end end
``` ```
...@@ -355,13 +428,15 @@ During `remap_all_vas()`: ...@@ -355,13 +428,15 @@ During `remap_all_vas()`:
### 4. Memory Layout Hash ### 4. Memory Layout Hash
On commit, the server computes a hash of: On commit, the server computes a hash of:
- All allocation IDs, sizes, and tags - All allocation layout slots, sizes, aligned sizes, and tags
- All metadata entries - All metadata keys, offsets, and values
On `remap_all_vas()`, this hash is checked: On `remap_all_vas()`, this hash is checked:
- If match: Safe to remap (layout unchanged) - If match: Safe to remap (layout unchanged)
- If mismatch: Raise `StaleMemoryLayoutError` (must re-import) - If mismatch: Raise `StaleMemoryLayoutError` (must re-import)
The hash is tied to the currently committed layout and is cleared as soon as a writer acquires RW.
**Important**: This detects **structural** changes, not **content** changes. **Important**: This detects **structural** changes, not **content** changes.
Weight values can be modified in-place (e.g., RL training updates) as long as the structure is preserved. Weight values can be modified in-place (e.g., RL training updates) as long as the structure is preserved.
...@@ -411,17 +486,16 @@ class GMSClientMemoryManager: ...@@ -411,17 +486,16 @@ class GMSClientMemoryManager:
# --- Tier 1: Connection --- # --- Tier 1: Connection ---
def connect(lock_type: RequestedLockType, timeout_ms: Optional[int] = None) -> None def connect(lock_type: RequestedLockType, timeout_ms: Optional[int] = None) -> None
def disconnect() -> None def abort() -> None
# --- Tier 1: Handle ops (server-side, RW only) --- # --- Tier 1: Handle ops (server-side, RW only) ---
def allocate_handle(size: int, tag: str = "default") -> str # Returns allocation_id def allocate_handle(size: int, tag: str = "default") -> Tuple[str, int] # Returns allocation_id, layout_slot
def export_handle(allocation_id: str) -> int # Returns FD def export_handle(allocation_id: str) -> int # Returns FD
def get_handle_info(allocation_id: str) -> AllocationInfo def get_handle_info(allocation_id: str) -> GetAllocationResponse
def free_handle(allocation_id: str) -> bool def free_handle(allocation_id: str) -> bool
def clear_all_handles() -> int # Returns count cleared def commit() -> bool # Sync + unmap local mappings + publish; raises on non-CUDA failure after unmap
def commit() -> bool # Transition to COMMITTED
def get_memory_layout_hash() -> str def get_memory_layout_hash() -> str
def list_handles(tag: Optional[str] = None) -> List[Dict] def list_handles(tag: Optional[str] = None) -> List[GetAllocationResponse]
# --- Tier 1: VA ops (local) --- # --- Tier 1: VA ops (local) ---
def reserve_va(size: int) -> int # Returns VA def reserve_va(size: int) -> int # Returns VA
...@@ -430,7 +504,7 @@ class GMSClientMemoryManager: ...@@ -430,7 +504,7 @@ class GMSClientMemoryManager:
def free_va(va: int) -> None # Releases VA reservation def free_va(va: int) -> None # Releases VA reservation
# --- Tier 1: Metadata --- # --- Tier 1: Metadata ---
def metadata_put(key: str, allocation_id: str, offset: int, value: bytes) -> bool def metadata_put(key: str, allocation_id: str, offset_bytes: int, value: bytes) -> bool
def metadata_get(key: str) -> Optional[Tuple[str, int, bytes]] def metadata_get(key: str) -> Optional[Tuple[str, int, bytes]]
def metadata_list(prefix: str = "") -> List[str] def metadata_list(prefix: str = "") -> List[str]
def metadata_delete(key: str) -> bool def metadata_delete(key: str) -> bool
...@@ -441,15 +515,9 @@ class GMSClientMemoryManager: ...@@ -441,15 +515,9 @@ class GMSClientMemoryManager:
def unmap_all_vas() -> None # Sync + unmap all, preserve VA reservations def unmap_all_vas() -> None # Sync + unmap all, preserve VA reservations
def remap_all_vas() -> None # Re-import at preserved VAs (checks layout hash) def remap_all_vas() -> None # Re-import at preserved VAs (checks layout hash)
def reallocate_all_handles(tag="default") -> None # Fresh server handles for preserved VAs def reallocate_all_handles(tag="default") -> None # Fresh server handles for preserved VAs
def close(free: bool = False) -> None def close() -> None
``` ```
## Limitations
1. **Single-GPU per server**: Each GMS server manages one GPU device
2. **CUDA VMM required**: Requires a GPU with Virtual Memory Management support. Check at runtime via `CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED` - there is no guaranteed minimum compute capability
3. **No content validation**: Remap doesn't detect in-place weight modifications
--- ---
## Framework Integration (vLLM / SGLang) ## Framework Integration (vLLM / SGLang)
...@@ -461,8 +529,10 @@ GMS provides pre-built integrations for vLLM and SGLang. Enable GMS by passing ` ...@@ -461,8 +529,10 @@ GMS provides pre-built integrations for vLLM and SGLang. Enable GMS by passing `
When `--load-format gms` is set: When `--load-format gms` is set:
1. **A GMS server must already be running** for the target GPU device. The engine connects to it via a Unix socket derived from the GPU UUID. 1. **A GMS server must already be running** for the target GPU device. The engine connects to it via a Unix socket derived from the GPU UUID.
2. The engine uses `RW_OR_RO` mode by default: the **first** process gets RW (loads weights from disk, commits to GMS), and **subsequent** processes get RO (import weights from GMS metadata). 2. The engine uses `RW_OR_RO` mode by default: if no committed layout exists and no writer holds the lock, the first process gets RW and loads weights from disk. Otherwise clients wait for a committed layout and then get RO to import published weights.
3. Weights are managed by GMS; KV cache is managed by the framework's own allocator (e.g., vLLM's `CuMemAllocator`). 3. Both weights and KV cache are managed by GMS, but they use separate tags:
- `weights`: publish/import flow (`RW_OR_RO`, then `RO` after commit)
- `kv_cache`: separate RW-only tag for mutable KV-cache memory
#### vLLM #### vLLM
...@@ -470,6 +540,7 @@ When `--load-format gms` is set: ...@@ -470,6 +540,7 @@ When `--load-format gms` is set:
python -m dynamo.vllm \ python -m dynamo.vllm \
--model <model> \ --model <model> \
--load-format gms \ --load-format gms \
--worker-cls gpu_memory_service.integrations.vllm.worker:GMSWorker \
--enable-sleep-mode \ --enable-sleep-mode \
--gpu-memory-utilization 0.9 --gpu-memory-utilization 0.9
``` ```
...@@ -478,7 +549,10 @@ The integration uses a custom worker class (`GMSWorker`) that: ...@@ -478,7 +549,10 @@ The integration uses a custom worker class (`GMSWorker`) that:
- Establishes the GMS connection early in `init_device()` so vLLM's `MemorySnapshot` can account for committed weights - Establishes the GMS connection early in `init_device()` so vLLM's `MemorySnapshot` can account for committed weights
- Registers a custom model loader (`GMSModelLoader`) for the `gms` load format - Registers a custom model loader (`GMSModelLoader`) for the `gms` load format
- Patches `torch.cuda.empty_cache` to avoid releasing GMS-managed memory - Patches `torch.cuda.empty_cache` to avoid releasing GMS-managed memory
- Routes weight allocation through a `CUDAPluggableAllocator` backed by GMS - Uses two GMS tags on the GPU:
- `weights`: normal publish/import flow (`RW_OR_RO`, then `RO` after commit)
- `kv_cache`: separate RW-only tag for mutable KV-cache memory
- Routes both weight and KV-cache allocation through a `CUDAPluggableAllocator` backed by the appropriate GMS tag
#### SGLang #### SGLang
...@@ -490,9 +564,10 @@ python -m dynamo.sglang \ ...@@ -490,9 +564,10 @@ python -m dynamo.sglang \
--mem-fraction-static 0.9 --mem-fraction-static 0.9
``` ```
The integration patches `torch_memory_saver` to route weight operations through GMS: The integration patches `torch_memory_saver` to route both weight and KV-cache operations through GMS:
- Weights (`"weights"` / `"model_weights"` tags) go through `GMSMemorySaverImpl` - Weights (`"weights"` / `"model_weights"` tags) use the `weights` GMS tag
- Other tags (e.g., `"kv_cache"`) are delegated to the default torch mempool implementation - KV cache (`"kv_cache"`) uses a separate RW-only `kv_cache` GMS tag
- Other tags still use the default torch mempool implementation
- The `--enable-memory-saver` flag is required to activate the memory saver pathway - The `--enable-memory-saver` flag is required to activate the memory saver pathway
### Shadow Engine Failover (Sleep / Wake) ### Shadow Engine Failover (Sleep / Wake)
...@@ -502,9 +577,14 @@ Both integrations support releasing and reclaiming GPU memory for shadow engine ...@@ -502,9 +577,14 @@ Both integrations support releasing and reclaiming GPU memory for shadow engine
- **vLLM**: `sleep` / `wake_up` (via `/engine/sleep` and `/engine/wake_up` HTTP endpoints) - **vLLM**: `sleep` / `wake_up` (via `/engine/sleep` and `/engine/wake_up` HTTP endpoints)
- **SGLang**: `release_memory_occupation` / `resume_memory_occupation` (via the corresponding HTTP endpoints) - **SGLang**: `release_memory_occupation` / `resume_memory_occupation` (via the corresponding HTTP endpoints)
Under the hood, sleeping calls `unmap_all_vas()` + `disconnect()` to release GPU memory while preserving VA reservations, and waking calls `connect(RO)` + `remap_all_vas()` to re-import weights at the same virtual addresses. Tensor pointers remain valid, so no model re-initialization is needed. Under the hood, sleeping calls `unmap_all_vas()` + `abort()` to release GPU memory while preserving VA reservations. Waking is tag-specific:
- **weights**: `connect(RO)` + `remap_all_vas()`
- **kv_cache**: `connect(RW)` + `reallocate_all_handles("kv_cache")` + `remap_all_vas()`
Tensor pointers remain valid because the original virtual addresses are preserved.
This enables a shadow engine to release its GPU memory, let a primary engine use the GPU, and then reclaim the memory after the primary is killed. This enables a shadow engine to release its GPU memory, let a primary engine use the GPU, and then reclaim the memory after the primary is killed. The mutable KV cache always moves through a fresh RW layout in its own GMS tag before it is reallocated.
### Configuration via `model_loader_extra_config` ### Configuration via `model_loader_extra_config`
......
...@@ -32,7 +32,9 @@ from gpu_memory_service.client.memory_manager import ( ...@@ -32,7 +32,9 @@ from gpu_memory_service.client.memory_manager import (
# PyTorch integration (GMS client memory manager) # PyTorch integration (GMS client memory manager)
from gpu_memory_service.client.torch.allocator import ( from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager, get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager, get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
) )
__all__ = [ __all__ = [
...@@ -42,4 +44,6 @@ __all__ = [ ...@@ -42,4 +44,6 @@ __all__ = [
# GMS client memory manager # GMS client memory manager
"get_or_create_gms_client_memory_manager", "get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager", "get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
] ]
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import argparse import argparse
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.utils import get_socket_path
...@@ -17,7 +18,10 @@ class Config: ...@@ -17,7 +18,10 @@ class Config:
"""Configuration for GPU Memory Service server.""" """Configuration for GPU Memory Service server."""
device: int device: int
tag: str
socket_path: str socket_path: str
alloc_retry_interval: float
alloc_retry_timeout: Optional[float]
verbose: bool verbose: bool
...@@ -33,6 +37,12 @@ def parse_args() -> Config: ...@@ -33,6 +37,12 @@ def parse_args() -> Config:
required=True, required=True,
help="CUDA device ID to manage memory for.", help="CUDA device ID to manage memory for.",
) )
parser.add_argument(
"--tag",
type=str,
default="weights",
help="Logical GMS tag for this server (default: weights).",
)
parser.add_argument( parser.add_argument(
"--socket-path", "--socket-path",
type=str, type=str,
...@@ -45,14 +55,33 @@ def parse_args() -> Config: ...@@ -45,14 +55,33 @@ def parse_args() -> Config:
action="store_true", action="store_true",
help="Enable verbose logging.", help="Enable verbose logging.",
) )
parser.add_argument(
"--alloc-retry-interval",
type=float,
default=0.5,
help="Seconds to sleep between allocation retries on CUDA OOM (default: 0.5).",
)
parser.add_argument(
"--alloc-retry-timeout",
type=float,
default=None,
help="Optional max seconds to wait for allocation retries before failing (default: wait indefinitely).",
)
args = parser.parse_args() args = parser.parse_args()
# Use UUID-based socket path by default (stable across CUDA_VISIBLE_DEVICES) # Use UUID-based socket path by default (stable across CUDA_VISIBLE_DEVICES)
socket_path = args.socket_path or get_socket_path(args.device) socket_path = args.socket_path or get_socket_path(args.device, args.tag)
if args.alloc_retry_interval <= 0:
parser.error("--alloc-retry-interval must be > 0")
if args.alloc_retry_timeout is not None and args.alloc_retry_timeout <= 0:
parser.error("--alloc-retry-timeout must be > 0 when set")
return Config( return Config(
device=args.device, device=args.device,
tag=args.tag,
socket_path=socket_path, socket_path=socket_path,
alloc_retry_interval=args.alloc_retry_interval,
alloc_retry_timeout=args.alloc_retry_timeout,
verbose=args.verbose, verbose=args.verbose,
) )
...@@ -13,7 +13,6 @@ Usage: ...@@ -13,7 +13,6 @@ Usage:
import asyncio import asyncio
import logging import logging
import signal
import uvloop import uvloop
from gpu_memory_service.server import GMSRPCServer from gpu_memory_service.server import GMSRPCServer
...@@ -37,33 +36,28 @@ async def worker() -> None: ...@@ -37,33 +36,28 @@ async def worker() -> None:
logging.getLogger("gpu_memory_service").setLevel(logging.DEBUG) logging.getLogger("gpu_memory_service").setLevel(logging.DEBUG)
logger.info(f"Starting GPU Memory Service Server for device {config.device}") logger.info(f"Starting GPU Memory Service Server for device {config.device}")
logger.info("GMS tag: %s", config.tag)
logger.info(f"Socket path: {config.socket_path}") logger.info(f"Socket path: {config.socket_path}")
logger.info(
server = GMSRPCServer(config.socket_path, device=config.device) "Allocation retry config: interval=%ss timeout=%s",
config.alloc_retry_interval,
# Set up shutdown handling (
shutdown_event = asyncio.Event() f"{config.alloc_retry_timeout}s"
if config.alloc_retry_timeout is not None
def signal_handler(): else "none"
logger.info("Received shutdown signal") ),
shutdown_event.set() )
loop = asyncio.get_running_loop() server = GMSRPCServer(
for sig in (signal.SIGTERM, signal.SIGINT): config.socket_path,
loop.add_signal_handler(sig, signal_handler) device=config.device,
allocation_retry_interval=config.alloc_retry_interval,
await server.start() allocation_retry_timeout=config.alloc_retry_timeout,
)
logger.info("GPU Memory Service Server ready, waiting for connections...") logger.info("GPU Memory Service Server ready, waiting for connections...")
logger.info(f"Clients can connect via socket: {config.socket_path}") logger.info(f"Clients can connect via socket: {config.socket_path}")
await server.serve()
# Wait for shutdown signal
try:
await shutdown_event.wait()
finally:
logger.info("Shutting down GPU Memory Service Server...")
await server.stop()
logger.info("GPU Memory Service Server shutdown complete")
def main() -> None: def main() -> None:
......
...@@ -7,7 +7,6 @@ This module provides the client-side components for interacting with the ...@@ -7,7 +7,6 @@ This module provides the client-side components for interacting with the
GPU Memory Service: GPU Memory Service:
- GMSClientMemoryManager: Manages local VA mappings of remote GPU memory - GMSClientMemoryManager: Manages local VA mappings of remote GPU memory
- GMSRPCClient: Low-level RPC client (pure Python, no PyTorch dependency)
For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch. For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch.
""" """
...@@ -16,10 +15,8 @@ from gpu_memory_service.client.memory_manager import ( ...@@ -16,10 +15,8 @@ from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager, GMSClientMemoryManager,
StaleMemoryLayoutError, StaleMemoryLayoutError,
) )
from gpu_memory_service.client.rpc import GMSRPCClient
__all__ = [ __all__ = [
"GMSClientMemoryManager", "GMSClientMemoryManager",
"StaleMemoryLayoutError", "StaleMemoryLayoutError",
"GMSRPCClient",
] ]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Client-side CUDA VMM utilities.
These functions wrap CUDA driver API calls used by the client memory manager
for importing, mapping, and unmapping GPU memory.
"""
from __future__ import annotations
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import check_cuda_result
from gpu_memory_service.common.types import GrantedLockType
def import_handle_from_fd(fd: int) -> int:
"""Import a CUDA memory handle from a file descriptor.
Closes the FD after import — the imported handle holds its own reference
to the physical allocation. Leaving the FD open leaks a DMA-buf ref that
prevents cuMemRelease from freeing GPU memory.
Args:
fd: POSIX file descriptor received via SCM_RIGHTS.
Returns:
CUDA memory handle.
"""
try:
result, handle = cuda.cuMemImportFromShareableHandle(
fd,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
)
check_cuda_result(result, "cuMemImportFromShareableHandle")
return int(handle)
finally:
os.close(fd)
def reserve_va(size: int, granularity: int) -> int:
"""Reserve virtual address space.
Args:
size: Size in bytes (should be aligned to granularity).
granularity: VMM allocation granularity.
Returns:
Reserved virtual address.
"""
result, va = cuda.cuMemAddressReserve(size, granularity, 0, 0)
check_cuda_result(result, "cuMemAddressReserve")
return int(va)
def free_va(va: int, size: int) -> None:
"""Free a virtual address reservation.
Args:
va: Virtual address to free.
size: Size of the reservation.
"""
(result,) = cuda.cuMemAddressFree(va, size)
check_cuda_result(result, "cuMemAddressFree")
def map_to_va(va: int, size: int, handle: int) -> None:
"""Map a CUDA handle to a virtual address.
Args:
va: Virtual address (must be reserved).
size: Size of the mapping.
handle: CUDA memory handle.
"""
(result,) = cuda.cuMemMap(va, size, 0, handle, 0)
check_cuda_result(result, "cuMemMap")
def set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
"""Set access permissions for a mapped region.
Args:
va: Virtual address.
size: Size of the region.
device: CUDA device index.
access: Access mode - RO for read-only, RW for read-write.
"""
acc = cuda.CUmemAccessDesc()
acc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
acc.location.id = device
acc.flags = (
cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
if access == GrantedLockType.RO
else cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
)
(result,) = cuda.cuMemSetAccess(va, size, [acc], 1)
check_cuda_result(result, "cuMemSetAccess")
def unmap(va: int, size: int) -> None:
"""Unmap a virtual address region.
Args:
va: Virtual address to unmap.
size: Size of the mapping.
"""
(result,) = cuda.cuMemUnmap(va, size)
check_cuda_result(result, "cuMemUnmap")
def release_handle(handle: int) -> None:
"""Release a CUDA memory handle.
Args:
handle: CUDA memory handle to release.
"""
(result,) = cuda.cuMemRelease(handle)
check_cuda_result(result, "cuMemRelease")
def validate_pointer(va: int) -> bool:
"""Validate that a mapped VA is accessible.
Returns True if the pointer is valid, False otherwise (logs a warning).
"""
result, _dev_ptr = cuda.cuPointerGetAttribute(
cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va
)
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
err_msg = ""
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
import logging
logging.getLogger(__name__).warning(
"cuPointerGetAttribute failed for VA 0x%x: %s (%s)",
va,
result,
err_msg,
)
return False
return True
def synchronize() -> None:
"""Synchronize the current CUDA context.
Blocks until all preceding commands in the current context have completed.
"""
(result,) = cuda.cuCtxSynchronize()
check_cuda_result(result, "cuCtxSynchronize")
def set_current_device(device: int) -> None:
"""Set the current CUDA device by activating its primary context.
Args:
device: CUDA device index.
"""
result, ctx = cuda.cuDevicePrimaryCtxRetain(device)
check_cuda_result(result, "cuDevicePrimaryCtxRetain")
(result,) = cuda.cuCtxSetCurrent(ctx)
check_cuda_result(result, "cuCtxSetCurrent")
...@@ -8,7 +8,7 @@ Two-tier API for GPU memory lifecycle management: ...@@ -8,7 +8,7 @@ Two-tier API for GPU memory lifecycle management:
Tier 1 (Atomic Operations): Tier 1 (Atomic Operations):
- Connection: connect(), disconnect() - Connection: connect(), disconnect()
- Handle ops (server-side cuMem allocations): allocate_handle, export_handle, - Handle ops (server-side cuMem allocations): allocate_handle, export_handle,
get_handle_info, free_handle, clear_all_handles, commit, list_handles, get_handle_info, free_handle, commit, list_handles,
get_memory_layout_hash get_memory_layout_hash
- VA ops (local address space): reserve_va, map_va, unmap_va, free_va - VA ops (local address space): reserve_va, map_va, unmap_va, free_va
- Metadata: metadata_put, metadata_get, metadata_list, metadata_delete - Metadata: metadata_put, metadata_get, metadata_list, metadata_delete
...@@ -34,25 +34,22 @@ import logging ...@@ -34,25 +34,22 @@ import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
from gpu_memory_service.client.cuda_vmm_utils import free_va as _cuda_free_va from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.client.cuda_vmm_utils import ( from gpu_memory_service.common.cuda_utils import (
import_handle_from_fd,
map_to_va,
release_handle,
)
from gpu_memory_service.client.cuda_vmm_utils import reserve_va as _cuda_reserve_va
from gpu_memory_service.client.cuda_vmm_utils import (
set_access,
set_current_device,
synchronize,
unmap,
validate_pointer,
)
from gpu_memory_service.client.rpc import GMSRPCClient
from gpu_memory_service.common.cuda_vmm_utils import (
align_to_granularity, align_to_granularity,
get_allocation_granularity, cuda_set_current_device,
cuda_synchronize,
cuda_validate_pointer,
cumem_address_free,
cumem_address_reserve,
cumem_get_allocation_granularity,
cumem_import_from_shareable_handle_close_fd,
cumem_map,
cumem_release,
cumem_set_access,
cumem_unmap,
) )
from gpu_memory_service.common.protocol.messages import GetAllocationResponse
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -97,6 +94,7 @@ class LocalMapping: ...@@ -97,6 +94,7 @@ class LocalMapping:
aligned_size: int aligned_size: int
handle: int # 0 if unmapped but VA reserved handle: int # 0 if unmapped but VA reserved
tag: str tag: str
layout_slot: int
def with_handle(self, handle: int) -> "LocalMapping": def with_handle(self, handle: int) -> "LocalMapping":
return LocalMapping( return LocalMapping(
...@@ -106,9 +104,14 @@ class LocalMapping: ...@@ -106,9 +104,14 @@ class LocalMapping:
self.aligned_size, self.aligned_size,
handle, handle,
self.tag, self.tag,
self.layout_slot,
) )
def with_allocation_id(self, allocation_id: str) -> "LocalMapping": def with_server_identity(
self,
allocation_id: str,
layout_slot: int,
) -> "LocalMapping":
return LocalMapping( return LocalMapping(
allocation_id, allocation_id,
self.va, self.va,
...@@ -116,6 +119,7 @@ class LocalMapping: ...@@ -116,6 +119,7 @@ class LocalMapping:
self.aligned_size, self.aligned_size,
self.handle, self.handle,
self.tag, self.tag,
layout_slot,
) )
...@@ -134,7 +138,7 @@ class GMSClientMemoryManager: ...@@ -134,7 +138,7 @@ class GMSClientMemoryManager:
self.socket_path = socket_path self.socket_path = socket_path
self.device = device self.device = device
self._client: Optional[GMSRPCClient] = None self._client: Optional[_GMSClientSession] = None
self._mappings: Dict[int, LocalMapping] = {} # va -> mapping self._mappings: Dict[int, LocalMapping] = {} # va -> mapping
self._inverse_mapping: Dict[str, int] = {} self._inverse_mapping: Dict[str, int] = {}
...@@ -145,8 +149,8 @@ class GMSClientMemoryManager: ...@@ -145,8 +149,8 @@ class GMSClientMemoryManager:
self._va_preserved = False self._va_preserved = False
self._last_memory_layout_hash: str = "" self._last_memory_layout_hash: str = ""
set_current_device(self.device) cuda_set_current_device(self.device)
self.granularity = get_allocation_granularity(device) self.granularity = cumem_get_allocation_granularity(device)
# ==================== Properties ==================== # ==================== Properties ====================
...@@ -180,40 +184,59 @@ class GMSClientMemoryManager: ...@@ -180,40 +184,59 @@ class GMSClientMemoryManager:
Updates self._granted_lock_type based on granted lock type. Saves memory layout hash Updates self._granted_lock_type based on granted lock type. Saves memory layout hash
for stale detection if server is in committed state. for stale detection if server is in committed state.
""" """
self._client = GMSRPCClient( if self._client is not None:
raise RuntimeError("Memory manager is already connected")
self._client = _GMSClientSession(
self.socket_path, self.socket_path,
lock_type=lock_type, lock_type=lock_type,
timeout_ms=timeout_ms, timeout_ms=timeout_ms,
) )
self._granted_lock_type = self._client.lock_type self._granted_lock_type = self._client.lock_type
# Save layout hash for stale detection on future remap if self._granted_lock_type == GrantedLockType.RW:
if self._client.committed: self._last_memory_layout_hash = ""
return
# Preserve the pre-unmap hash across reconnects so remap_all_vas can
# detect that another writer changed the committed layout while this
# process was disconnected.
if self._client.committed and (
not self._va_preserved or not self._last_memory_layout_hash
):
self._last_memory_layout_hash = self._client.get_memory_layout_hash() self._last_memory_layout_hash = self._client.get_memory_layout_hash()
elif not self._va_preserved:
self._last_memory_layout_hash = ""
def disconnect(self) -> None: def abort(self) -> None:
"""Close connection and release lock.""" """Drop the GMS session.
Clean callers should unmap first. This also supports abrupt session
drop with live mappings still present.
"""
if self._client is not None: if self._client is not None:
try: try:
self._client.close() self._client.close()
except Exception: finally:
pass
self._client = None self._client = None
self._granted_lock_type = None
return
self._granted_lock_type = None
# ==================== Tier 1: Handle Operations (server-side) ==================== # ==================== Tier 1: Handle Operations (server-side) ====================
def allocate_handle(self, size: int, tag: str = "default") -> str: def allocate_handle(self, size: int, tag: str = "default") -> tuple[str, int]:
"""Allocate a cuMem handle on the server. """Allocate a cuMem handle on the server.
Returns allocation_id. Size is aligned to VMM granularity before sending. Returns allocation_id and layout_slot. Size is aligned to VMM granularity
before sending.
""" """
self._require_rw() self._require_rw()
aligned_size = align_to_granularity(size, self.granularity) aligned_size = align_to_granularity(size, self.granularity)
allocation_id, server_aligned = self._client_rpc.allocate(aligned_size, tag) response = self._client_rpc.allocate_info(aligned_size, tag)
if int(server_aligned) != aligned_size: if int(response.aligned_size) != aligned_size:
raise RuntimeError( raise RuntimeError(
f"Alignment mismatch: {aligned_size} vs {server_aligned}" "GMS allocation alignment mismatch: "
f"{aligned_size} vs {response.aligned_size}"
) )
return allocation_id return response.allocation_id, int(response.layout_slot)
def export_handle(self, allocation_id: str) -> int: def export_handle(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD.""" """Export allocation as POSIX FD."""
...@@ -225,34 +248,43 @@ class GMSClientMemoryManager: ...@@ -225,34 +248,43 @@ class GMSClientMemoryManager:
def free_handle(self, allocation_id: str) -> bool: def free_handle(self, allocation_id: str) -> bool:
"""Release a cuMem allocation on the server.""" """Release a cuMem allocation on the server."""
return self._client_rpc.free(allocation_id) ok = self._client_rpc.free(allocation_id)
if not ok:
raise RuntimeError(
f"GMS free_handle failed for allocation_id={allocation_id}"
)
return True
def clear_all_handles(self) -> int: def commit(self) -> bool:
"""Clear all allocations on the server. NO local unmap. """Synchronize, unmap writer mappings, then commit.
Safe at startup (no local mappings) and during failover Commit is a publish barrier. It guarantees all prior GPU writes in the
(preserves local VA reservations). current context are complete before the server transitions state. After
a successful commit, the former writer process no longer has any mapped
access to the published allocations. Any failure after local unmap
raises because the process cannot safely recover its CUDA VMM state.
""" """
self._require_rw() self._require_rw()
return self._client_rpc.clear_all()
def commit(self) -> bool: # Publish barrier: all writer-side GPU work must be visible before commit.
"""Server-only commit: transition to COMMITTED state. cuda_synchronize()
No synchronize(), no CUDA access flip. The caller is responsible for for mapping in list(self._mappings.values()):
synchronizing before calling this. Server closes the RW socket on if mapping.handle != 0:
success, so self._client becomes None. self.unmap_va(mapping.va)
"""
self._require_rw() self._va_preserved = True
ok = self._client_rpc.commit() self._unmapped = True
if ok:
self._client_rpc.commit()
self._client = None self._client = None
return bool(ok) self._granted_lock_type = None
return True
def get_memory_layout_hash(self) -> str: def get_memory_layout_hash(self) -> str:
return self._client_rpc.get_memory_layout_hash() return self._client_rpc.get_memory_layout_hash()
def list_handles(self, tag: Optional[str] = None) -> List[Dict]: def list_handles(self, tag: Optional[str] = None) -> List[GetAllocationResponse]:
return self._client_rpc.list_allocations(tag) return self._client_rpc.list_allocations(tag)
# ==================== Tier 1: Metadata ==================== # ==================== Tier 1: Metadata ====================
...@@ -276,26 +308,26 @@ class GMSClientMemoryManager: ...@@ -276,26 +308,26 @@ class GMSClientMemoryManager:
def reserve_va(self, size: int) -> int: def reserve_va(self, size: int) -> int:
"""Reserve virtual address space (cuMemAddressReserve). No tracking.""" """Reserve virtual address space (cuMemAddressReserve). No tracking."""
aligned_size = align_to_granularity(size, self.granularity) aligned_size = align_to_granularity(size, self.granularity)
return _cuda_reserve_va(aligned_size, self.granularity) return cumem_address_reserve(aligned_size, self.granularity)
def map_va(self, fd: int, va: int, size: int, allocation_id: str, tag: str) -> int: def map_va(
self,
fd: int,
va: int,
size: int,
allocation_id: str,
tag: str,
layout_slot: int,
) -> int:
"""Import FD + cuMemMap + set access + track. """Import FD + cuMemMap + set access + track.
Access is set based on current lock_type. Returns the CUDA handle. Access is set based on current lock_type. Returns the CUDA handle.
""" """
assert self._granted_lock_type is not None assert self._granted_lock_type is not None
aligned_size = align_to_granularity(size, self.granularity) aligned_size = align_to_granularity(size, self.granularity)
handle = import_handle_from_fd(fd) handle = cumem_import_from_shareable_handle_close_fd(fd)
try: cumem_map(va, aligned_size, handle)
map_to_va(va, aligned_size, handle) cumem_set_access(va, aligned_size, self.device, self._granted_lock_type)
set_access(va, aligned_size, self.device, self._granted_lock_type)
except Exception:
try:
unmap(va, aligned_size)
except Exception:
pass
release_handle(handle)
raise
self._track_mapping( self._track_mapping(
LocalMapping( LocalMapping(
allocation_id=allocation_id, allocation_id=allocation_id,
...@@ -304,6 +336,7 @@ class GMSClientMemoryManager: ...@@ -304,6 +336,7 @@ class GMSClientMemoryManager:
aligned_size=aligned_size, aligned_size=aligned_size,
handle=handle, handle=handle,
tag=tag, tag=tag,
layout_slot=layout_slot,
) )
) )
return handle return handle
...@@ -317,8 +350,8 @@ class GMSClientMemoryManager: ...@@ -317,8 +350,8 @@ class GMSClientMemoryManager:
mapping = self._mappings.get(va) mapping = self._mappings.get(va)
if mapping is None or mapping.handle == 0: if mapping is None or mapping.handle == 0:
return return
unmap(va, mapping.aligned_size) cumem_unmap(va, mapping.aligned_size)
release_handle(mapping.handle) cumem_release(mapping.handle)
self._mappings[va] = mapping.with_handle(0) self._mappings[va] = mapping.with_handle(0)
def free_va(self, va: int) -> None: def free_va(self, va: int) -> None:
...@@ -334,7 +367,7 @@ class GMSClientMemoryManager: ...@@ -334,7 +367,7 @@ class GMSClientMemoryManager:
mapping = self._mappings.get(va) mapping = self._mappings.get(va)
if mapping is None: if mapping is None:
return return
_cuda_free_va(va, mapping.aligned_size) cumem_address_free(va, mapping.aligned_size)
self._mappings.pop(va, None) self._mappings.pop(va, None)
self._inverse_mapping.pop(mapping.allocation_id, None) self._inverse_mapping.pop(mapping.allocation_id, None)
...@@ -370,28 +403,21 @@ class GMSClientMemoryManager: ...@@ -370,28 +403,21 @@ class GMSClientMemoryManager:
alloc_size = int(info.size) alloc_size = int(info.size)
aligned_size = int(info.aligned_size) aligned_size = int(info.aligned_size)
alloc_tag = str(getattr(info, "tag", "default")) alloc_tag = str(getattr(info, "tag", "default"))
layout_slot = int(info.layout_slot)
fd = self.export_handle(allocation_id) fd = self.export_handle(allocation_id)
va = self.reserve_va(aligned_size) va = self.reserve_va(aligned_size)
try: self.map_va(fd, va, alloc_size, allocation_id, alloc_tag, layout_slot)
self.map_va(fd, va, alloc_size, allocation_id, alloc_tag)
except Exception:
_cuda_free_va(va, align_to_granularity(aligned_size, self.granularity))
raise
return va return va
# Allocate path # Allocate path
if size <= 0: if size <= 0:
raise ValueError("size must be > 0 when allocation_id is None") raise ValueError("size must be > 0 when allocation_id is None")
alloc_id = self.allocate_handle(size, tag) alloc_id, layout_slot = self.allocate_handle(size, tag)
fd = self.export_handle(alloc_id) fd = self.export_handle(alloc_id)
aligned_size = align_to_granularity(size, self.granularity) aligned_size = align_to_granularity(size, self.granularity)
va = self.reserve_va(aligned_size) va = self.reserve_va(aligned_size)
try: self.map_va(fd, va, size, alloc_id, tag, layout_slot)
self.map_va(fd, va, size, alloc_id, tag)
except Exception:
_cuda_free_va(va, aligned_size)
raise
return va return va
def destroy_mapping(self, va: int) -> None: def destroy_mapping(self, va: int) -> None:
...@@ -402,38 +428,25 @@ class GMSClientMemoryManager: ...@@ -402,38 +428,25 @@ class GMSClientMemoryManager:
alloc_id = mapping.allocation_id alloc_id = mapping.allocation_id
try:
self.unmap_va(va)
except Exception as e:
logger.warning("Error in unmap_va for 0x%x: %s", va, e)
try:
self.free_va(va)
except Exception as e:
logger.warning("Error in free_va for 0x%x: %s", va, e)
# Only free server handle if we're RW and haven't committed # Only free server handle if we're RW and haven't committed
if self._granted_lock_type == GrantedLockType.RW: if self._granted_lock_type == GrantedLockType.RW:
try:
self.free_handle(alloc_id) self.free_handle(alloc_id)
except Exception:
pass self.unmap_va(va)
self.free_va(va)
def unmap_all_vas(self) -> None: def unmap_all_vas(self) -> None:
"""Synchronize + unmap all VAs. Preserves VA reservations for remap.""" """Synchronize + unmap all VAs. Preserves VA reservations for remap."""
synchronize() cuda_synchronize()
unmapped_count = 0 unmapped_count = 0
total_bytes = 0 total_bytes = 0
for va, mapping in list(self._mappings.items()): for va, mapping in list(self._mappings.items()):
if mapping.handle == 0: if mapping.handle == 0:
continue continue
try:
self.unmap_va(va) self.unmap_va(va)
unmapped_count += 1 unmapped_count += 1
total_bytes += mapping.aligned_size total_bytes += mapping.aligned_size
except Exception as e:
logger.warning("Error unmapping VA 0x%x: %s", va, e)
self._va_preserved = True self._va_preserved = True
self._unmapped = True self._unmapped = True
...@@ -451,7 +464,7 @@ class GMSClientMemoryManager: ...@@ -451,7 +464,7 @@ class GMSClientMemoryManager:
Checks layout hash for staleness. Validates each allocation still Checks layout hash for staleness. Validates each allocation still
exists and size matches before remapping. exists and size matches before remapping.
""" """
set_current_device(self.device) cuda_set_current_device(self.device)
# Stale layout check # Stale layout check
current_hash = self.get_memory_layout_hash() current_hash = self.get_memory_layout_hash()
...@@ -465,36 +478,50 @@ class GMSClientMemoryManager: ...@@ -465,36 +478,50 @@ class GMSClientMemoryManager:
assert self._granted_lock_type is not None assert self._granted_lock_type is not None
allocations_by_slot = {
int(info.layout_slot): info for info in self.list_handles()
}
remapped_count = 0 remapped_count = 0
total_bytes = 0 total_bytes = 0
for va, mapping in list(self._mappings.items()): for va, mapping in sorted(
self._mappings.items(), key=lambda item: item[1].layout_slot
):
if mapping.handle != 0: if mapping.handle != 0:
continue # Already mapped continue
# Validate allocation still exists alloc_info = allocations_by_slot.get(mapping.layout_slot)
try: if alloc_info is None:
alloc_info = self.get_handle_info(mapping.allocation_id)
except Exception as e:
raise StaleMemoryLayoutError( raise StaleMemoryLayoutError(
f"Allocation {mapping.allocation_id} no longer exists: {e}" f"Layout slot {mapping.layout_slot} is missing from the committed layout"
) from e )
if int(alloc_info.aligned_size) != mapping.aligned_size: if int(alloc_info.aligned_size) != mapping.aligned_size:
raise StaleMemoryLayoutError( raise StaleMemoryLayoutError(
f"Allocation {mapping.allocation_id} size changed: " f"Layout slot {mapping.layout_slot} size changed: "
f"{mapping.aligned_size} vs {int(alloc_info.aligned_size)}" f"{mapping.aligned_size} vs {int(alloc_info.aligned_size)}"
) )
if str(alloc_info.tag) != mapping.tag:
raise StaleMemoryLayoutError(
f"Layout slot {mapping.layout_slot} tag changed: "
f"{mapping.tag} vs {alloc_info.tag}"
)
# Re-import and map to preserved VA fd = self.export_handle(alloc_info.allocation_id)
fd = self.export_handle(mapping.allocation_id) handle = cumem_import_from_shareable_handle_close_fd(fd)
handle = import_handle_from_fd(fd) cumem_map(va, mapping.aligned_size, handle)
map_to_va(va, mapping.aligned_size, handle) cumem_set_access(
set_access(va, mapping.aligned_size, self.device, self._granted_lock_type) va, mapping.aligned_size, self.device, self._granted_lock_type
)
synchronize() cuda_synchronize()
validate_pointer(va) cuda_validate_pointer(va)
self._mappings[va] = mapping.with_handle(handle) if mapping.allocation_id != alloc_info.allocation_id:
self._inverse_mapping.pop(mapping.allocation_id, None)
self._mappings[va] = mapping.with_server_identity(
alloc_info.allocation_id,
int(alloc_info.layout_slot),
).with_handle(handle)
self._inverse_mapping[alloc_info.allocation_id] = va
remapped_count += 1 remapped_count += 1
total_bytes += mapping.aligned_size total_bytes += mapping.aligned_size
...@@ -523,24 +550,26 @@ class GMSClientMemoryManager: ...@@ -523,24 +550,26 @@ class GMSClientMemoryManager:
) )
reallocated = 0 reallocated = 0
for va, mapping in list(self._mappings.items()): for va, mapping in sorted(
self._mappings.items(), key=lambda item: item[1].layout_slot
):
if mapping.handle != 0: if mapping.handle != 0:
continue continue
# Allocate fresh handle on server (uses raw RPC to avoid re-aligning) response = self._client_rpc.allocate_info(mapping.aligned_size, tag)
allocation_id, server_aligned = self._client_rpc.allocate( if int(response.aligned_size) != mapping.aligned_size:
mapping.aligned_size, tag
)
if int(server_aligned) != mapping.aligned_size:
raise RuntimeError( raise RuntimeError(
f"Alignment mismatch during reallocation: " "GMS reallocation alignment mismatch: "
f"{mapping.aligned_size} vs {server_aligned}" f"{mapping.aligned_size} vs {response.aligned_size}"
) )
allocation_id = response.allocation_id
# Update tracking: new allocation_id, handle stays 0
old_alloc_id = mapping.allocation_id old_alloc_id = mapping.allocation_id
self._inverse_mapping.pop(old_alloc_id, None) self._inverse_mapping.pop(old_alloc_id, None)
self._mappings[va] = mapping.with_allocation_id(allocation_id) self._mappings[va] = mapping.with_server_identity(
allocation_id,
int(response.layout_slot),
)
self._inverse_mapping[allocation_id] = va self._inverse_mapping[allocation_id] = va
reallocated += 1 reallocated += 1
...@@ -551,43 +580,25 @@ class GMSClientMemoryManager: ...@@ -551,43 +580,25 @@ class GMSClientMemoryManager:
# ==================== Lifecycle ==================== # ==================== Lifecycle ====================
def close(self, free: bool = False) -> None: def close(self) -> None:
"""Best-effort cleanup. NOT reliable in crash/signal paths. """Strict cleanup.
synchronize + unmap all + free all VAs + disconnect. synchronize + unmap all + free all VAs + abort.
free=True: also clear_all_handles() on server before disconnect.
VAs are freed by CUDA context teardown on process exit anyway.
""" """
try: cuda_synchronize()
synchronize()
except Exception:
pass
for va in list(self._mappings.keys()): for va in list(self._mappings.keys()):
try:
self.unmap_va(va) self.unmap_va(va)
except Exception as e:
logger.warning("Error unmapping VA 0x%x during close: %s", va, e)
for va in list(self._mappings.keys()):
try:
self.free_va(va) self.free_va(va)
except Exception as e:
logger.warning("Error freeing VA 0x%x during close: %s", va, e)
if (
free
and self._client is not None
and self._granted_lock_type == GrantedLockType.RW
):
try:
self.clear_all_handles()
except Exception as e:
logger.warning("Error clearing handles during close: %s", e)
self.disconnect() self.abort()
self._unmapped = False self._unmapped = False
self._va_preserved = False self._va_preserved = False
from gpu_memory_service.client.torch.allocator import (
evict_gms_client_memory_manager,
)
evict_gms_client_memory_manager(self)
def __enter__(self) -> "GMSClientMemoryManager": def __enter__(self) -> "GMSClientMemoryManager":
return self return self
...@@ -598,7 +609,7 @@ class GMSClientMemoryManager: ...@@ -598,7 +609,7 @@ class GMSClientMemoryManager:
# ==================== Internals ==================== # ==================== Internals ====================
@property @property
def _client_rpc(self) -> GMSRPCClient: def _client_rpc(self) -> _GMSClientSession:
"""Get connected client or raise.""" """Get connected client or raise."""
if self._client is None: if self._client is None:
if self._unmapped: if self._unmapped:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service RPC Client. """Internal GPU Memory Service transport.
Low-level RPC client stub. The client provides a simple interface for acquiring This module only owns Unix socket transport and typed request/response exchange.
locks and performing allocation operations. The socket connection IS the lock. Session semantics live in `gpu_memory_service.client.session`.
This module has NO PyTorch dependency.
Usage:
# Writer (acquires RW lock in constructor)
with GMSRPCClient(socket_path, lock_type=RequestedLockType.RW) as client:
alloc_id, aligned_size = client.allocate(size=1024*1024)
fd = client.export(alloc_id)
# ... write weights using fd ...
client.commit()
# Lock released on exit
# Reader (acquires RO lock in constructor)
client = GMSRPCClient(socket_path, lock_type=RequestedLockType.RO)
if client.committed: # Check if weights are valid
allocations = client.list_allocations()
for alloc in allocations:
fd = client.export(alloc["allocation_id"])
# ... import and map fd ...
# Keep connection open during inference!
# client.close() only when done with inference
""" """
from __future__ import annotations
import logging import logging
import os
import socket import socket
from typing import Dict, List, Optional, Tuple, Type, TypeVar from typing import Optional, Tuple, Type, TypeVar
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
ClearAllRequest,
ClearAllResponse,
CommitRequest,
CommitResponse,
ErrorResponse, ErrorResponse,
ExportRequest,
FreeRequest,
FreeResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetLockStateRequest,
GetLockStateResponse,
GetStateHashRequest,
GetStateHashResponse,
HandshakeRequest, HandshakeRequest,
HandshakeResponse, HandshakeResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
) )
from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync
from gpu_memory_service.common.types import ( from gpu_memory_service.common.types import RequestedLockType
RW_REQUIRED,
GrantedLockType,
RequestedLockType,
)
T = TypeVar("T") T = TypeVar("T")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GMSRPCClient: class _GMSRPCTransport:
"""GPU Memory Service RPC Client. """Raw GMS Unix socket transport."""
CRITICAL: Socket connection IS the lock.
- Constructor blocks until lock is acquired
- close() releases the lock
- committed property tells readers if weights are valid
For writers (lock_type=RequestedLockType.RW): def __init__(self, socket_path: str):
- Use context manager (with statement) for automatic lock release
- Call commit() after weights are written
- Call clear_all() before loading new model
For readers (lock_type=RequestedLockType.RO):
- Check committed property after construction
- Keep connection open during inference lifetime
- Only call close() when shutting down or allowing weight updates
"""
def __init__(
self,
socket_path: str,
lock_type: RequestedLockType = RequestedLockType.RO,
timeout_ms: Optional[int] = None,
):
"""Connect to Allocation Server and acquire lock.
Args:
socket_path: Path to server's Unix domain socket
lock_type: Requested lock type (RW, RO, or RW_OR_RO)
timeout_ms: Timeout in milliseconds for lock acquisition.
None means wait indefinitely.
Raises:
ConnectionError: If connection fails
TimeoutError: If timeout_ms expires waiting for lock
"""
self.socket_path = socket_path self.socket_path = socket_path
self._requested_lock_type = lock_type
self._socket: Optional[socket.socket] = None self._socket: Optional[socket.socket] = None
self._recv_buffer = bytearray() self._recv_buffer = bytearray()
self._committed = False
self._granted_lock_type: Optional[GrantedLockType] = None
# Connect and acquire lock @property
self._connect(timeout_ms=timeout_ms) def is_connected(self) -> bool:
return self._socket is not None
def _connect(self, timeout_ms: Optional[int]) -> None: def connect(self) -> None:
"""Connect to server and perform handshake (lock acquisition)."""
self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try: try:
self._socket.connect(self.socket_path) self._socket.connect(self.socket_path)
except FileNotFoundError: except FileNotFoundError:
self._socket.close() self._socket.close()
self._socket = None self._socket = None
raise ConnectionError(f"Server not running at {self.socket_path}") from None raise ConnectionError(
except Exception as e: f"GMS server not running at {self.socket_path}"
) from None
except Exception as exc:
self._socket.close() self._socket.close()
self._socket = None self._socket = None
raise ConnectionError(f"Failed to connect: {e}") from e raise ConnectionError(f"Failed to connect to GMS: {exc}") from exc
# Handshake I/O — clean up socket on any failure def handshake(
try: self,
request = HandshakeRequest( lock_type: RequestedLockType,
lock_type=self._requested_lock_type, timeout_ms: Optional[int],
timeout_ms=timeout_ms, ) -> HandshakeResponse:
response, _ = self.request_with_fd(
HandshakeRequest(lock_type=lock_type, timeout_ms=timeout_ms),
HandshakeResponse,
error_prefix="GMS handshake",
) )
send_message_sync(self._socket, request) return response
# May block waiting for lock def request(self, request, response_type: Type[T]) -> T:
response, _, self._recv_buffer = recv_message_sync( response, fd = self.request_with_fd(request, response_type)
self._socket, self._recv_buffer if fd >= 0:
os.close(fd)
raise RuntimeError(
f"GMS request {type(request).__name__} returned an unexpected FD"
) )
except Exception: return response
self._socket.close()
self._socket = None
raise
if isinstance(response, ErrorResponse):
self._socket.close()
self._socket = None
raise ConnectionError(f"Handshake error: {response.error}")
if not isinstance(response, HandshakeResponse):
self._socket.close()
self._socket = None
raise ConnectionError(f"Unexpected response: {type(response)}")
if not response.success:
self._socket.close()
self._socket = None
raise TimeoutError("Timeout waiting for lock")
self._committed = response.committed def request_with_fd(
# Store granted lock type (may differ from requested for rw_or_ro mode) self,
if response.granted_lock_type is not None: request,
self._granted_lock_type = response.granted_lock_type response_type: Type[T],
elif self._requested_lock_type == RequestedLockType.RW: *,
self._granted_lock_type = GrantedLockType.RW error_prefix: Optional[str] = None,
else: ) -> Tuple[T, int]:
self._granted_lock_type = GrantedLockType.RO response, fd = self._send_recv(request, error_prefix=error_prefix)
logger.info( if not isinstance(response, response_type):
f"Connected with {self._requested_lock_type.value} lock (granted={self._granted_lock_type.value}), " prefix = error_prefix or f"GMS request {type(request).__name__}"
f"committed={self._committed}" if fd >= 0:
os.close(fd)
raise RuntimeError(
f"{prefix} returned unexpected response type: {type(response)}"
) )
return response, fd
@property def _send_recv(
def committed(self) -> bool: self, request, *, error_prefix: Optional[str] = None
"""Check if weights are committed (valid).""" ) -> Tuple[object, int]:
return self._committed if self._socket is None:
raise RuntimeError("Attempted GMS request on disconnected transport")
@property
def lock_type(self) -> Optional[GrantedLockType]:
"""Get the lock type actually granted by the server.
For rw_or_ro mode, this tells you whether RW or RO was granted.
"""
return self._granted_lock_type
@property
def is_connected(self) -> bool:
"""Check if client is connected."""
return self._socket is not None
def _send_recv(self, request) -> Tuple[object, int]:
"""Send request and receive response. Returns (response, fd)."""
if not self._socket:
raise RuntimeError("Client not connected")
prefix = error_prefix or f"GMS request {type(request).__name__}"
try:
send_message_sync(self._socket, request) send_message_sync(self._socket, request)
response, fd, self._recv_buffer = recv_message_sync( response, fd, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer self._socket, self._recv_buffer
) )
except Exception as exc:
try:
self._socket.close()
except Exception:
pass
self._socket = None
raise ConnectionError(f"{prefix} failed: {exc}") from exc
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
raise RuntimeError(f"Server error: {response.error}") if fd >= 0:
os.close(fd)
raise RuntimeError(f"{prefix} error: {response.error}")
return response, fd return response, fd
def _call(self, request, response_type: Type[T]) -> T:
"""Send request, validate response type, return typed response."""
if type(request) in RW_REQUIRED and self.lock_type != GrantedLockType.RW:
raise RuntimeError("Operation requires RW connection")
response, _ = self._send_recv(request)
if not isinstance(response, response_type):
raise RuntimeError(f"Unexpected response: {type(response)}")
return response
def get_lock_state(self) -> GetLockStateResponse:
return self._call(GetLockStateRequest(), GetLockStateResponse)
def get_allocation_state(self) -> GetAllocationStateResponse:
return self._call(GetAllocationStateRequest(), GetAllocationStateResponse)
def is_ready(self) -> bool:
return self.committed
def commit(self) -> bool:
"""Commit weights and release RW lock. Returns True on success."""
if CommitRequest in RW_REQUIRED and self.lock_type != GrantedLockType.RW:
raise RuntimeError("Operation requires RW connection")
try:
response, _ = self._send_recv(CommitRequest())
ok = isinstance(response, CommitResponse) and response.success
except (ConnectionResetError, BrokenPipeError, OSError) as e:
# Server closes RW socket as part of commit
logger.debug(
f"Commit saw socket error ({type(e).__name__}); verifying via RO connect"
)
self.close()
try:
ro = GMSRPCClient(
self.socket_path, lock_type=RequestedLockType.RO, timeout_ms=1000
)
try:
ok = ro.committed
finally:
ro.close()
except TimeoutError:
ok = False
if ok:
self._committed = True
self.close()
logger.info("Committed weights and released RW connection")
return True
return False
def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]:
"""Returns (allocation_id, aligned_size)."""
r = self._call(AllocateRequest(size=size, tag=tag), AllocateResponse)
return r.allocation_id, r.aligned_size
def export(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD. Caller must close."""
_, fd = self._send_recv(ExportRequest(allocation_id=allocation_id))
if fd < 0:
raise RuntimeError("No FD received from server")
return fd
def get_allocation(self, allocation_id: str) -> GetAllocationResponse:
return self._call(
GetAllocationRequest(allocation_id=allocation_id), GetAllocationResponse
)
def list_allocations(self, tag: Optional[str] = None) -> List[Dict]:
return self._call(
ListAllocationsRequest(tag=tag), ListAllocationsResponse
).allocations
def free(self, allocation_id: str) -> bool:
return self._call(
FreeRequest(allocation_id=allocation_id), FreeResponse
).success
def clear_all(self) -> int:
return self._call(ClearAllRequest(), ClearAllResponse).cleared_count
def metadata_put(
self, key: str, allocation_id: str, offset_bytes: int, value: bytes
) -> bool:
req = MetadataPutRequest(
key=key, allocation_id=allocation_id, offset_bytes=offset_bytes, value=value
)
return self._call(req, MetadataPutResponse).success
def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]:
"""Returns (allocation_id, offset_bytes, value) or None if not found."""
r = self._call(MetadataGetRequest(key=key), MetadataGetResponse)
return (r.allocation_id, r.offset_bytes, r.value) if r.found else None
def metadata_delete(self, key: str) -> bool:
return self._call(
MetadataDeleteRequest(key=key), MetadataDeleteResponse
).deleted
def metadata_list(self, prefix: str = "") -> List[str]:
return self._call(MetadataListRequest(prefix=prefix), MetadataListResponse).keys
def get_memory_layout_hash(self) -> str:
"""Get state hash (hash of allocations + metadata). Empty if not committed."""
return self._call(
GetStateHashRequest(), GetStateHashResponse
).memory_layout_hash
def close(self) -> None: def close(self) -> None:
"""Close connection and release lock.""" if self._socket is None:
if self._socket: return
try: try:
self._socket.close() self._socket.close()
except Exception: except Exception as exc:
pass raise ConnectionError(
f"Failed to close GMS transport socket: {exc}"
) from exc
finally:
self._socket = None self._socket = None
lock_str = self.lock_type.value if self.lock_type else "unknown"
logger.info(f"Closed {lock_str} connection")
def __enter__(self) -> "GMSRPCClient": def __enter__(self) -> "_GMSRPCTransport":
"""Context manager entry."""
return self return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None: def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Context manager exit."""
self.close() self.close()
def __del__(self): def __del__(self):
"""Destructor: warn if connection not closed.""" if self._socket is not None:
if self._socket: try:
logger.warning("GMSRPCClient not closed properly") self._socket.close()
except Exception:
pass
self._socket = None
logger.warning("_GMSRPCTransport not closed properly")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Internal GPU Memory Service client session."""
from __future__ import annotations
import logging
from typing import List, Optional, Tuple
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
CommitRequest,
CommitResponse,
ExportAllocationRequest,
ExportAllocationResponse,
FreeAllocationRequest,
FreeAllocationResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetLockStateRequest,
GetLockStateResponse,
GetStateHashRequest,
GetStateHashResponse,
HandshakeResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__)
class _GMSClientSession:
"""Connected GMS client session with granted lock state."""
def __init__(
self,
socket_path: str,
lock_type: RequestedLockType,
timeout_ms: Optional[int],
):
self._requested_lock_type = lock_type
self._transport = _GMSRPCTransport(socket_path)
self._transport.connect()
try:
response = self._transport.handshake(lock_type, timeout_ms)
except Exception:
try:
self._transport.close()
except Exception:
pass
raise
self._initialize_from_handshake(response)
def _initialize_from_handshake(self, response: HandshakeResponse) -> None:
if not response.success:
self._transport.close()
raise TimeoutError("Timeout waiting for lock")
self._committed = response.committed
if response.granted_lock_type is None:
self._transport.close()
raise RuntimeError("HandshakeResponse omitted granted_lock_type")
self._granted_lock_type = response.granted_lock_type
logger.info(
"Connected with %s lock (granted=%s), committed=%s",
self._requested_lock_type.value,
self._granted_lock_type.value,
self._committed,
)
@property
def committed(self) -> bool:
return self._committed
@property
def lock_type(self) -> GrantedLockType:
return self._granted_lock_type
@property
def is_connected(self) -> bool:
return self._transport.is_connected
def get_lock_state(self) -> GetLockStateResponse:
return self._transport.request(GetLockStateRequest(), GetLockStateResponse)
def get_allocation_state(self) -> GetAllocationStateResponse:
return self._transport.request(
GetAllocationStateRequest(), GetAllocationStateResponse
)
def is_ready(self) -> bool:
return self.committed
def commit(self) -> bool:
response = self._transport.request(CommitRequest(), CommitResponse)
if not response.success:
raise RuntimeError("GMS commit returned failure")
self._committed = True
try:
self.close()
except ConnectionError as exc:
logger.warning("Commit succeeded but closing transport failed: %s", exc)
logger.info("Committed weights and released RW connection")
return True
def allocate_info(self, size: int, tag: str = "default") -> AllocateResponse:
return self._transport.request(
AllocateRequest(size=size, tag=tag), AllocateResponse
)
def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]:
response = self.allocate_info(size=size, tag=tag)
return response.allocation_id, response.aligned_size
def export(self, allocation_id: str) -> int:
response, fd = self._transport.request_with_fd(
ExportAllocationRequest(allocation_id=allocation_id),
ExportAllocationResponse,
)
if fd < 0:
raise RuntimeError(
f"GMS export returned no FD for allocation_id={allocation_id}"
)
return fd
def get_allocation(self, allocation_id: str) -> GetAllocationResponse:
return self._transport.request(
GetAllocationRequest(allocation_id=allocation_id),
GetAllocationResponse,
)
def list_allocations(
self, tag: Optional[str] = None
) -> List[GetAllocationResponse]:
return self._transport.request(
ListAllocationsRequest(tag=tag),
ListAllocationsResponse,
).allocations
def free(self, allocation_id: str) -> bool:
return self._transport.request(
FreeAllocationRequest(allocation_id=allocation_id),
FreeAllocationResponse,
).success
def metadata_put(
self, key: str, allocation_id: str, offset_bytes: int, value: bytes
) -> bool:
return self._transport.request(
MetadataPutRequest(
key=key,
allocation_id=allocation_id,
offset_bytes=offset_bytes,
value=value,
),
MetadataPutResponse,
).success
def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]:
response = self._transport.request(
MetadataGetRequest(key=key), MetadataGetResponse
)
if not response.found:
return None
return response.allocation_id, response.offset_bytes, response.value
def metadata_delete(self, key: str) -> bool:
return self._transport.request(
MetadataDeleteRequest(key=key), MetadataDeleteResponse
).deleted
def metadata_list(self, prefix: str = "") -> List[str]:
return self._transport.request(
MetadataListRequest(prefix=prefix), MetadataListResponse
).keys
def get_memory_layout_hash(self) -> str:
return self._transport.request(
GetStateHashRequest(), GetStateHashResponse
).memory_layout_hash
def close(self) -> None:
self._transport.close()
logger.info("Closed %s connection", self._granted_lock_type.value)
def __enter__(self) -> "_GMSClientSession":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
...@@ -12,7 +12,9 @@ This module provides PyTorch-specific functionality: ...@@ -12,7 +12,9 @@ This module provides PyTorch-specific functionality:
from gpu_memory_service.client.torch.allocator import ( from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager, get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager, get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
) )
from gpu_memory_service.client.torch.module import ( from gpu_memory_service.client.torch.module import (
materialize_module_from_gms, materialize_module_from_gms,
...@@ -23,6 +25,8 @@ __all__ = [ ...@@ -23,6 +25,8 @@ __all__ = [
# GMS client memory manager # GMS client memory manager
"get_or_create_gms_client_memory_manager", "get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager", "get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
# Tensor operations (public API) # Tensor operations (public API)
"register_module_tensors", "register_module_tensors",
"materialize_module_from_gms", "materialize_module_from_gms",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service allocator management (singleton). """GPU Memory Service allocator registry for PyTorch integration."""
Manages a single weights memory manager and PyTorch MemPool integration.
Only one GMS scope is needed: weights. KV cache is handled by CuMemAllocator.
"""
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any, Optional, Tuple from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Optional
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
if TYPE_CHECKING: if TYPE_CHECKING:
import torch
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool from torch.cuda.memory import MemPool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Singleton state
_manager: Optional["GMSClientMemoryManager"] = None @dataclass
_mem_pool: Optional["MemPool"] = None class _TagState:
_tag: str = "weights" manager: "GMSClientMemoryManager"
_callbacks_initialized: bool = False mem_pool: "MemPool | None"
_pluggable_alloc: Optional[Any] = None socket_path: str
device: int
_tag_states: dict[str, _TagState] = {}
_active_tag: ContextVar[str | None] = ContextVar(
"gpu_memory_service_active_tag",
default=None,
)
_callbacks_initialized = False
_pluggable_alloc: Any | None = None
def _gms_malloc(size: int, device: int, stream: int) -> int: def _gms_malloc(size: int, device: int, stream: int) -> int:
"""Route malloc to the singleton weights manager.""" tag = _active_tag.get()
if _manager is None: if tag is None:
raise RuntimeError("No GMS manager initialized") raise RuntimeError("No active GMS allocation tag")
va = _manager.create_mapping(size=int(size), tag=_tag)
logger.debug("[GMS] malloc: va=0x%x size=%d", va, size) state = _tag_states.get(tag)
if state is None:
raise RuntimeError(f"Unknown GMS allocation tag: {tag}")
va = state.manager.create_mapping(size=int(size), tag=tag)
logger.debug("[GMS] malloc(tag=%s): va=0x%x size=%d", tag, va, size)
return va return va
def _gms_free(ptr: int, size: int, device: int, stream: int) -> None: def _gms_free(ptr: int, size: int, device: int, stream: int) -> None:
"""Route free to the singleton weights manager.""" va = int(ptr)
if _manager is None: for tag, state in _tag_states.items():
logger.warning("[GMS] free: no manager, ignoring va=0x%x", ptr) if va not in state.manager.mappings:
continue
logger.debug("[GMS] free(tag=%s): va=0x%x size=%d", tag, va, size)
state.manager.destroy_mapping(va)
return return
if int(ptr) in _manager.mappings: logger.warning("[GMS] free: no manager owns va=0x%x, ignoring", va)
logger.debug("[GMS] free: va=0x%x size=%d", ptr, size)
_manager.destroy_mapping(int(ptr))
else:
logger.warning("[GMS] free: manager does not own va=0x%x, ignoring", ptr)
def _ensure_callbacks_initialized() -> "MemPool": def _ensure_callbacks_initialized() -> None:
"""Initialize C-level callbacks exactly once, return a new MemPool."""
global _callbacks_initialized, _pluggable_alloc global _callbacks_initialized, _pluggable_alloc
from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem
from torch.cuda import CUDAPluggableAllocator from torch.cuda import CUDAPluggableAllocator
from torch.cuda.memory import MemPool
if not _callbacks_initialized: if _callbacks_initialized:
_pluggable_alloc = CUDAPluggableAllocator( return
cumem.__file__, "my_malloc", "my_free"
) _pluggable_alloc = CUDAPluggableAllocator(cumem.__file__, "my_malloc", "my_free")
cumem.init_module(_gms_malloc, _gms_free) cumem.init_module(_gms_malloc, _gms_free)
_callbacks_initialized = True _callbacks_initialized = True
def _create_mem_pool() -> "MemPool":
from torch.cuda.memory import MemPool
assert _pluggable_alloc is not None
return MemPool(allocator=_pluggable_alloc.allocator()) return MemPool(allocator=_pluggable_alloc.allocator())
...@@ -74,66 +91,98 @@ def get_or_create_gms_client_memory_manager( ...@@ -74,66 +91,98 @@ def get_or_create_gms_client_memory_manager(
*, *,
tag: str = "weights", tag: str = "weights",
timeout_ms: Optional[int] = None, timeout_ms: Optional[int] = None,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]: ) -> "GMSClientMemoryManager":
"""Get existing memory manager, or create a new one.
Args:
socket_path: Unix socket path for the allocation server.
device: CUDA device index.
mode: RW for cold start, RO for import-only, RW_OR_RO for auto.
tag: Allocation tag for RW mode.
timeout_ms: Lock acquisition timeout (None = wait indefinitely).
Returns:
(gms_client_memory_manager, pool) - pool is None for RO mode.
"""
global _manager, _mem_pool, _tag
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
if _manager is not None: state = _tag_states.get(tag)
return _get_existing(mode) if state is not None:
if state.socket_path != socket_path or state.device != device:
raise RuntimeError(
f"GMS allocator tag={tag} was initialized for "
f"{state.socket_path} on device {state.device}, not {socket_path} "
f"on device {device}"
)
manager = state.manager
if not manager.is_connected:
if manager.mappings or manager.is_unmapped or manager.granted_lock_type:
raise RuntimeError(
f"GMS allocator tag={tag} is disconnected but still owns "
"preserved state; recreate the process instead of reusing it"
)
manager._client = None
manager._granted_lock_type = None
_tag_states.pop(tag, None)
state = None
if state is not None:
current = state.manager.granted_lock_type
if mode == RequestedLockType.RW and current != GrantedLockType.RW:
raise RuntimeError(
f"Cannot get RW allocator for tag {tag}: existing is in {current} mode"
)
if mode == RequestedLockType.RO and current != GrantedLockType.RO:
raise RuntimeError(
f"Cannot get RO allocator for tag {tag}: existing is in {current} mode"
)
return state.manager
manager = GMSClientMemoryManager(socket_path, device=device) manager = GMSClientMemoryManager(socket_path, device=device)
manager.connect(mode, timeout_ms=timeout_ms) manager.connect(mode, timeout_ms=timeout_ms)
mem_pool = None
if manager.granted_lock_type == GrantedLockType.RW: if manager.granted_lock_type == GrantedLockType.RW:
pool = _ensure_callbacks_initialized() _ensure_callbacks_initialized()
# Only set globals after mempool succeeds (avoids partial singleton) mem_pool = _create_mem_pool()
_manager = manager
_tag = tag _tag_states[tag] = _TagState(
_mem_pool = pool manager=manager,
logger.info("[GMS] Created RW allocator (device=%d)", device) mem_pool=mem_pool,
return manager, pool socket_path=socket_path,
else: device=device,
_manager = manager )
_tag = tag logger.info(
logger.info("[GMS] Created RO allocator (device=%d)", device) "[GMS] Created %s allocator for tag=%s (device=%d)",
return manager, None manager.granted_lock_type.value,
tag,
device,
def _get_existing( )
mode: RequestedLockType, return manager
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Return existing allocator if mode-compatible."""
assert _manager is not None
current = _manager.granted_lock_type
if mode == RequestedLockType.RW:
if current == GrantedLockType.RW:
return _manager, _mem_pool
raise RuntimeError(f"Cannot get RW allocator: existing is in {current} mode")
if mode == RequestedLockType.RO: def get_gms_client_memory_manager(
if current == GrantedLockType.RO: tag: str = "weights",
return _manager, None ) -> "GMSClientMemoryManager | None":
raise RuntimeError(f"Cannot get RO allocator: existing is in {current} mode") state = _tag_states.get(tag)
if state is None:
return None
return state.manager
def get_gms_client_memory_managers() -> tuple["GMSClientMemoryManager", ...]:
return tuple(state.manager for state in _tag_states.values())
def evict_gms_client_memory_manager(manager: "GMSClientMemoryManager") -> None:
for tag, state in list(_tag_states.items()):
if state.manager is manager:
_tag_states.pop(tag, None)
return
# RW_OR_RO: return whatever exists @contextmanager
effective_pool = _mem_pool if current == GrantedLockType.RW else None def gms_use_mem_pool(tag: str, device: "torch.device | int") -> Iterator[None]:
return _manager, effective_pool import torch
state = _tag_states.get(tag)
if state is None:
raise RuntimeError(f"No GMS allocator initialized for tag={tag}")
if state.mem_pool is None:
raise RuntimeError(f"GMS allocator tag={tag} does not have a mempool")
def get_gms_client_memory_manager() -> Optional["GMSClientMemoryManager"]: token = _active_tag.set(tag)
"""Get the active GMS client memory manager, or None.""" try:
return _manager with torch.cuda.use_mem_pool(state.mem_pool, device=device):
yield
finally:
_active_tag.reset(token)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA driver helpers shared by the GMS client and server."""
from __future__ import annotations
import atexit
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.utils import fail
_primary_contexts: dict[int, object] = {}
_primary_context_release_registered = False
def cuda_check_result(result: cuda.CUresult, name: str) -> None:
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
else:
err_msg = str(result)
fail("fatal CUDA VMM error in %s: %s", name, err_msg)
def cuda_ensure_initialized() -> None:
(result,) = cuda.cuInit(0)
cuda_check_result(result, "cuInit")
def cumem_get_allocation_granularity(device: int) -> int:
"""Get VMM allocation granularity for a device.
Args:
device: CUDA device index
Returns:
Allocation granularity in bytes (typically 2 MiB)
"""
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, granularity = cuda.cuMemGetAllocationGranularity(
prop, cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
)
cuda_check_result(result, "cuMemGetAllocationGranularity")
return int(granularity)
def cumem_create_tolerate_oom(size: int, device: int) -> tuple[bool, int]:
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, handle = cuda.cuMemCreate(size, prop, 0)
if result == cuda.CUresult.CUDA_SUCCESS:
return True, int(handle)
if result == cuda.CUresult.CUDA_ERROR_OUT_OF_MEMORY:
return False, 0
cuda_check_result(result, "cuMemCreate")
return False, 0
def cumem_export_to_shareable_handle(handle: int) -> int:
result, fd = cuda.cuMemExportToShareableHandle(
handle,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
0,
)
cuda_check_result(result, "cuMemExportToShareableHandle")
return int(fd)
def align_to_granularity(size: int, granularity: int) -> int:
"""Align size up to VMM granularity.
Args:
size: Size in bytes
granularity: Allocation granularity
Returns:
Aligned size
"""
return ((size + granularity - 1) // granularity) * granularity
def cumem_import_from_shareable_handle_close_fd(fd: int) -> int:
try:
result, handle = cuda.cuMemImportFromShareableHandle(
fd,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
)
cuda_check_result(result, "cuMemImportFromShareableHandle")
return int(handle)
finally:
os.close(fd)
def cumem_address_reserve(size: int, granularity: int) -> int:
result, va = cuda.cuMemAddressReserve(size, granularity, 0, 0)
cuda_check_result(result, "cuMemAddressReserve")
return int(va)
def cumem_address_free(va: int, size: int) -> None:
(result,) = cuda.cuMemAddressFree(va, size)
cuda_check_result(result, "cuMemAddressFree")
def cumem_map(va: int, size: int, handle: int) -> None:
(result,) = cuda.cuMemMap(va, size, 0, handle, 0)
cuda_check_result(result, "cuMemMap")
def cumem_set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
access_desc = cuda.CUmemAccessDesc()
access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
access_desc.location.id = device
access_desc.flags = (
cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
if access == GrantedLockType.RO
else cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
)
(result,) = cuda.cuMemSetAccess(va, size, [access_desc], 1)
cuda_check_result(result, "cuMemSetAccess")
def cumem_unmap(va: int, size: int) -> None:
(result,) = cuda.cuMemUnmap(va, size)
cuda_check_result(result, "cuMemUnmap")
def cumem_release(handle: int) -> None:
(result,) = cuda.cuMemRelease(handle)
cuda_check_result(result, "cuMemRelease")
def cuda_validate_pointer(va: int) -> None:
result, _ = cuda.cuPointerGetAttribute(
cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va
)
cuda_check_result(result, "cuPointerGetAttribute")
def cuda_synchronize() -> None:
(result,) = cuda.cuCtxSynchronize()
cuda_check_result(result, "cuCtxSynchronize")
def cuda_set_current_device(device: int) -> None:
global _primary_context_release_registered
ctx = _primary_contexts.get(device)
if ctx is None:
result, ctx = cuda.cuDevicePrimaryCtxRetain(device)
cuda_check_result(result, "cuDevicePrimaryCtxRetain")
_primary_contexts[device] = ctx
if not _primary_context_release_registered:
_primary_context_release_registered = True
atexit.register(_release_primary_contexts)
(result,) = cuda.cuCtxSetCurrent(ctx)
cuda_check_result(result, "cuCtxSetCurrent")
def _release_primary_contexts() -> None:
for device in list(_primary_contexts):
try:
(result,) = cuda.cuDevicePrimaryCtxRelease(device)
except Exception:
continue
if result == cuda.CUresult.CUDA_SUCCESS:
_primary_contexts.pop(device, None)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA Virtual Memory Management (VMM) utility functions.
This module provides utility functions for CUDA driver API operations
used by both server (GMSServerMemoryManager) and client (GMSClientMemoryManager).
"""
from cuda.bindings import driver as cuda
def check_cuda_result(result: cuda.CUresult, name: str) -> None:
"""Check CUDA driver API result and raise on error.
Args:
result: CUDA driver API return code (CUresult enum)
name: Operation name for error message
Raises:
RuntimeError: If result is not CUDA_SUCCESS
"""
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
else:
err_msg = str(result)
raise RuntimeError(f"{name}: {err_msg}")
def ensure_cuda_initialized() -> None:
"""Ensure CUDA driver is initialized.
Raises:
RuntimeError: If cuInit fails
"""
(result,) = cuda.cuInit(0)
check_cuda_result(result, "cuInit")
def get_allocation_granularity(device: int) -> int:
"""Get VMM allocation granularity for a device.
Args:
device: CUDA device index
Returns:
Allocation granularity in bytes (typically 2 MiB)
"""
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, granularity = cuda.cuMemGetAllocationGranularity(
prop, cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
)
check_cuda_result(result, "cuMemGetAllocationGranularity")
return int(granularity)
def align_to_granularity(size: int, granularity: int) -> int:
"""Align size up to VMM granularity.
Args:
size: Size in bytes
granularity: Allocation granularity
Returns:
Aligned size
"""
return ((size + granularity - 1) // granularity) * granularity
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Message types for GPU Memory Service RPC protocol.""" """Message types for GPU Memory Service RPC protocol."""
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Union from typing import List, Optional, Union
import msgspec import msgspec
...@@ -62,7 +62,6 @@ class GetAllocationStateRequest(msgspec.Struct, tag="get_allocation_state_reques ...@@ -62,7 +62,6 @@ class GetAllocationStateRequest(msgspec.Struct, tag="get_allocation_state_reques
class GetAllocationStateResponse(msgspec.Struct, tag="get_allocation_state_response"): class GetAllocationStateResponse(msgspec.Struct, tag="get_allocation_state_response"):
allocation_count: int allocation_count: int
total_bytes: int
class AllocateRequest(msgspec.Struct, tag="allocate_request"): class AllocateRequest(msgspec.Struct, tag="allocate_request"):
...@@ -74,12 +73,21 @@ class AllocateResponse(msgspec.Struct, tag="allocate_response"): ...@@ -74,12 +73,21 @@ class AllocateResponse(msgspec.Struct, tag="allocate_response"):
allocation_id: str allocation_id: str
size: int size: int
aligned_size: int aligned_size: int
layout_slot: int
class ExportRequest(msgspec.Struct, tag="export_request"): class ExportAllocationRequest(msgspec.Struct, tag="export_allocation_request"):
allocation_id: str allocation_id: str
class ExportAllocationResponse(msgspec.Struct, tag="export_allocation_response"):
allocation_id: str
size: int
aligned_size: int
tag: str
layout_slot: int
class GetAllocationRequest(msgspec.Struct, tag="get_allocation_request"): class GetAllocationRequest(msgspec.Struct, tag="get_allocation_request"):
allocation_id: str allocation_id: str
...@@ -89,6 +97,7 @@ class GetAllocationResponse(msgspec.Struct, tag="get_allocation_response"): ...@@ -89,6 +97,7 @@ class GetAllocationResponse(msgspec.Struct, tag="get_allocation_response"):
size: int size: int
aligned_size: int aligned_size: int
tag: str tag: str
layout_slot: int
class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"): class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
...@@ -96,25 +105,17 @@ class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"): ...@@ -96,25 +105,17 @@ class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
class ListAllocationsResponse(msgspec.Struct, tag="list_allocations_response"): class ListAllocationsResponse(msgspec.Struct, tag="list_allocations_response"):
allocations: List[Dict[str, Any]] = [] allocations: List[GetAllocationResponse] = []
class FreeRequest(msgspec.Struct, tag="free_request"): class FreeAllocationRequest(msgspec.Struct, tag="free_allocation_request"):
allocation_id: str allocation_id: str
class FreeResponse(msgspec.Struct, tag="free_response"): class FreeAllocationResponse(msgspec.Struct, tag="free_allocation_response"):
success: bool success: bool
class ClearAllRequest(msgspec.Struct, tag="clear_all_request"):
pass
class ClearAllResponse(msgspec.Struct, tag="clear_all_response"):
cleared_count: int
class ErrorResponse(msgspec.Struct, tag="error_response"): class ErrorResponse(msgspec.Struct, tag="error_response"):
error: str error: str
code: int = 0 code: int = 0
...@@ -166,6 +167,34 @@ class GetStateHashResponse(msgspec.Struct, tag="get_memory_layout_hash_response" ...@@ -166,6 +167,34 @@ class GetStateHashResponse(msgspec.Struct, tag="get_memory_layout_hash_response"
memory_layout_hash: str # Hash of allocations + metadata, empty if not committed memory_layout_hash: str # Hash of allocations + metadata, empty if not committed
class GetRuntimeStateRequest(msgspec.Struct, tag="get_runtime_state_request"):
pass
class GetRuntimeStateResponse(msgspec.Struct, tag="get_runtime_state_response"):
state: str
has_rw_session: bool
ro_session_count: int
waiting_writers: int
committed: bool
is_ready: bool
allocation_count: int = 0
memory_layout_hash: str = ""
class GMSRuntimeEvent(msgspec.Struct):
kind: str
allocation_count: int = 0
class GetEventHistoryRequest(msgspec.Struct, tag="get_event_history_request"):
pass
class GetEventHistoryResponse(msgspec.Struct, tag="get_event_history_response"):
events: List[GMSRuntimeEvent] = []
Message = Union[ Message = Union[
HandshakeRequest, HandshakeRequest,
HandshakeResponse, HandshakeResponse,
...@@ -177,15 +206,14 @@ Message = Union[ ...@@ -177,15 +206,14 @@ Message = Union[
GetAllocationStateResponse, GetAllocationStateResponse,
AllocateRequest, AllocateRequest,
AllocateResponse, AllocateResponse,
ExportRequest, ExportAllocationRequest,
ExportAllocationResponse,
GetAllocationRequest, GetAllocationRequest,
GetAllocationResponse, GetAllocationResponse,
ListAllocationsRequest, ListAllocationsRequest,
ListAllocationsResponse, ListAllocationsResponse,
FreeRequest, FreeAllocationRequest,
FreeResponse, FreeAllocationResponse,
ClearAllRequest,
ClearAllResponse,
ErrorResponse, ErrorResponse,
MetadataPutRequest, MetadataPutRequest,
MetadataPutResponse, MetadataPutResponse,
...@@ -197,6 +225,10 @@ Message = Union[ ...@@ -197,6 +225,10 @@ Message = Union[
MetadataListResponse, MetadataListResponse,
GetStateHashRequest, GetStateHashRequest,
GetStateHashResponse, GetStateHashResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
GetEventHistoryRequest,
GetEventHistoryResponse,
] ]
_encoder = msgspec.msgpack.Encoder() _encoder = msgspec.msgpack.Encoder()
......
...@@ -96,7 +96,11 @@ async def recv_message( ...@@ -96,7 +96,11 @@ async def recv_message(
raw_msg, fds, _flags, _addr = await loop.run_in_executor( raw_msg, fds, _flags, _addr = await loop.run_in_executor(
None, lambda: socket.recv_fds(raw_sock, 65536, 1) None, lambda: socket.recv_fds(raw_sock, 65536, 1)
) )
for extra_fd in fds[1:]:
os.close(extra_fd)
if not raw_msg: if not raw_msg:
if fds:
os.close(fds[0])
raise ConnectionResetError("Connection closed") raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg) recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1 fd = fds[0] if fds else -1
...@@ -107,6 +111,7 @@ async def recv_message( ...@@ -107,6 +111,7 @@ async def recv_message(
recv_buffer.extend(chunk) recv_buffer.extend(chunk)
# Try to extract message, read more if needed # Try to extract message, read more if needed
try:
msg, remaining, bytes_needed = _try_extract_message(recv_buffer) msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0: while msg is None and bytes_needed > 0:
if raw_sock is not None: if raw_sock is not None:
...@@ -120,8 +125,11 @@ async def recv_message( ...@@ -120,8 +125,11 @@ async def recv_message(
raise ConnectionResetError("Connection closed") raise ConnectionResetError("Connection closed")
remaining.extend(chunk) remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining) msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining return msg, fd, remaining
except Exception:
if fd >= 0:
os.close(fd)
raise
# ==================== Sync (for client) ==================== # ==================== Sync (for client) ====================
...@@ -153,12 +161,17 @@ def recv_message_sync( ...@@ -153,12 +161,17 @@ def recv_message_sync(
# Receive more data (with potential FD) # Receive more data (with potential FD)
raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1) raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1)
for extra_fd in fds[1:]:
os.close(extra_fd)
if not raw_msg: if not raw_msg:
if fds:
os.close(fds[0])
raise ConnectionResetError("Connection closed") raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg) recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1 fd = fds[0] if fds else -1
# Try to extract message, read more if needed # Try to extract message, read more if needed
try:
msg, remaining, bytes_needed = _try_extract_message(recv_buffer) msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0: while msg is None and bytes_needed > 0:
chunk = sock.recv(bytes_needed) chunk = sock.recv(bytes_needed)
...@@ -166,5 +179,8 @@ def recv_message_sync( ...@@ -166,5 +179,8 @@ def recv_message_sync(
raise ConnectionResetError("Connection closed") raise ConnectionResetError("Connection closed")
remaining.extend(chunk) remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining) msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining return msg, fd, remaining
except Exception:
if fd >= 0:
os.close(fd)
raise
...@@ -8,10 +8,9 @@ from enum import Enum, auto ...@@ -8,10 +8,9 @@ from enum import Enum, auto
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
AllocateRequest, AllocateRequest,
ClearAllRequest,
CommitRequest, CommitRequest,
ExportRequest, ExportAllocationRequest,
FreeRequest, FreeAllocationRequest,
GetAllocationRequest, GetAllocationRequest,
GetAllocationStateRequest, GetAllocationStateRequest,
GetLockStateRequest, GetLockStateRequest,
...@@ -89,8 +88,7 @@ def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState: ...@@ -89,8 +88,7 @@ def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState:
RW_REQUIRED: frozenset[type] = frozenset( RW_REQUIRED: frozenset[type] = frozenset(
{ {
AllocateRequest, AllocateRequest,
FreeRequest, FreeAllocationRequest,
ClearAllRequest,
MetadataPutRequest, MetadataPutRequest,
MetadataDeleteRequest, MetadataDeleteRequest,
CommitRequest, CommitRequest,
...@@ -99,7 +97,7 @@ RW_REQUIRED: frozenset[type] = frozenset( ...@@ -99,7 +97,7 @@ RW_REQUIRED: frozenset[type] = frozenset(
RO_ALLOWED: frozenset[type] = frozenset( RO_ALLOWED: frozenset[type] = frozenset(
{ {
ExportRequest, ExportAllocationRequest,
GetAllocationRequest, GetAllocationRequest,
ListAllocationsRequest, ListAllocationsRequest,
MetadataGetRequest, MetadataGetRequest,
......
...@@ -3,36 +3,39 @@ ...@@ -3,36 +3,39 @@
"""Shared utilities for GPU Memory Service.""" """Shared utilities for GPU Memory Service."""
import logging
import os import os
import tempfile import tempfile
import uuid from typing import NoReturn
from cuda.bindings import driver as cuda logger = logging.getLogger(__name__)
from gpu_memory_service.common.cuda_vmm_utils import (
check_cuda_result,
ensure_cuda_initialized,
)
def get_socket_path(device: int) -> str: def fail(message: str, *args, exc_info=None) -> NoReturn:
"""Get GMS socket path for the given CUDA device. logger.critical(message, *args, exc_info=exc_info)
logging.shutdown()
os._exit(1)
The socket path is based on GPU UUID resolved by CUDA.
CUDA_VISIBLE_DEVICES remapping is handled by CUDA device enumeration. def get_socket_path(device: int, tag: str = "weights") -> str:
"""Get GMS socket path for the given CUDA device and tag.
The socket path is based on GPU UUID, making it stable across different
CUDA_VISIBLE_DEVICES configurations.
Args: Args:
device: CUDA device index. device: CUDA device index.
Returns: Returns:
Socket path (e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc.sock"). Socket path
(e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock").
""" """
ensure_cuda_initialized() import pynvml
result, cu_device = cuda.cuDeviceGet(device) pynvml.nvmlInit()
check_cuda_result(result, "cuDeviceGet") try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
result, cu_uuid = cuda.cuDeviceGetUuid(cu_device) uuid = pynvml.nvmlDeviceGetUUID(handle)
check_cuda_result(result, "cuDeviceGetUuid") finally:
pynvml.nvmlShutdown()
gpu_uuid = f"GPU-{uuid.UUID(bytes=bytes(cu_uuid.bytes))}" return os.path.join(tempfile.gettempdir(), f"gms_{uuid}_{tag}.sock")
return os.path.join(tempfile.gettempdir(), f"gms_{gpu_uuid}.sock")
...@@ -8,7 +8,7 @@ from __future__ import annotations ...@@ -8,7 +8,7 @@ from __future__ import annotations
import logging import logging
import torch import torch
from gpu_memory_service import get_gms_client_memory_manager from gpu_memory_service.client.torch.allocator import get_gms_client_memory_managers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -32,11 +32,13 @@ def patch_empty_cache() -> None: ...@@ -32,11 +32,13 @@ def patch_empty_cache() -> None:
_original_empty_cache = torch.cuda.empty_cache _original_empty_cache = torch.cuda.empty_cache
def safe_empty_cache() -> None: def safe_empty_cache() -> None:
manager = get_gms_client_memory_manager() mapping_count = sum(
if manager is not None and len(manager.mappings) > 0: len(manager.mappings) for manager in get_gms_client_memory_managers()
)
if mapping_count > 0:
logger.debug( logger.debug(
"[GMS] Skipping torch.cuda.empty_cache() - %d VMM allocations active", "[GMS] Skipping torch.cuda.empty_cache() - %d VMM allocations active",
len(manager.mappings), mapping_count,
) )
return return
_original_empty_cache() _original_empty_cache()
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import replace
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
...@@ -29,6 +30,20 @@ def get_gms_lock_mode(extra_config: dict): ...@@ -29,6 +30,20 @@ def get_gms_lock_mode(extra_config: dict):
return RequestedLockType.RW_OR_RO return RequestedLockType.RW_OR_RO
def strip_gms_model_loader_config(load_config, load_format: str):
"""Copy a loader config with GMS-only keys removed for backend loaders."""
extra_config = getattr(load_config, "model_loader_extra_config", {}) or {}
return replace(
load_config,
load_format=load_format,
model_loader_extra_config={
key: value
for key, value in extra_config.items()
if not key.startswith("gms_")
},
)
def setup_meta_tensor_workaround() -> None: def setup_meta_tensor_workaround() -> None:
"""Enable workaround for meta tensor operations like torch.nonzero().""" """Enable workaround for meta tensor operations like torch.nonzero()."""
try: try:
...@@ -42,9 +57,9 @@ def setup_meta_tensor_workaround() -> None: ...@@ -42,9 +57,9 @@ def setup_meta_tensor_workaround() -> None:
def finalize_gms_write( def finalize_gms_write(
allocator: "GMSClientMemoryManager", model: torch.nn.Module allocator: "GMSClientMemoryManager", model: torch.nn.Module
) -> int: ) -> int:
"""Finalize GMS write mode: register tensors, commit, switch to read. """Finalize GMS write mode: register tensors, commit, reconnect in read mode.
Flow: register tensors -> sync -> commit (server-only) -> disconnect -> connect(RO) Flow: register tensors -> sync -> unmap + commit -> connect(RO) -> remap
Args: Args:
allocator: The GMS client memory manager in write mode. allocator: The GMS client memory manager in write mode.
...@@ -52,9 +67,6 @@ def finalize_gms_write( ...@@ -52,9 +67,6 @@ def finalize_gms_write(
Returns: Returns:
Total bytes committed. Total bytes committed.
Raises:
RuntimeError: If commit fails.
""" """
from gpu_memory_service.client.torch.module import register_module_tensors from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.types import RequestedLockType from gpu_memory_service.common.types import RequestedLockType
...@@ -65,12 +77,10 @@ def finalize_gms_write( ...@@ -65,12 +77,10 @@ def finalize_gms_write(
# Synchronize before commit — caller's writes must be visible # Synchronize before commit — caller's writes must be visible
torch.cuda.synchronize() torch.cuda.synchronize()
if not allocator.commit(): allocator.commit()
raise RuntimeError("GMS commit failed")
# commit() closed the RW socket; acquire RO for inference
allocator.disconnect() # no-op if commit already cleared _client, but safe
allocator.connect(RequestedLockType.RO) allocator.connect(RequestedLockType.RO)
allocator.remap_all_vas()
logger.info( logger.info(
"[GMS] Committed %.2f GiB, switched to read mode with %d mappings", "[GMS] Committed %.2f GiB, switched to read mode with %d mappings",
......
...@@ -3,13 +3,10 @@ ...@@ -3,13 +3,10 @@
"""Hybrid torch_memory_saver implementation for GPU Memory Service. """Hybrid torch_memory_saver implementation for GPU Memory Service.
This module provides a hybrid implementation that combines: This module uses:
1. GPU Memory Service allocator for "weights" tag (VA-stable unmap/remap, shared) 1. GPU Memory Service for "weights" (shared RO/RW publish flow)
2. Torch mempool mode for other tags like "kv_cache" (CPU backup, per-instance) 2. GPU Memory Service for "kv_cache" (RW-only failover flow)
3. torch_memory_saver for any remaining tags
The impl uses RW_OR_RO mode to connect to GMS:
- First process gets RW lock and loads weights from disk
- Subsequent processes get RO lock and import weights from metadata
""" """
from __future__ import annotations from __future__ import annotations
...@@ -19,10 +16,13 @@ from contextlib import contextmanager ...@@ -19,10 +16,13 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
if TYPE_CHECKING: if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -39,56 +39,54 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]: ...@@ -39,56 +39,54 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
class GMSMemorySaverImpl: class GMSMemorySaverImpl:
"""Hybrid implementation: GMS for weights, torch mempool for KV cache. """Hybrid implementation: GMS for weights and KV cache."""
Routes operations based on tag:
- "weights" or "model_weights": Handled by GMS allocator (VA-stable)
- Other tags (e.g., "kv_cache"): Delegated to torch mempool mode
"""
def __init__( def __init__(
self, self,
torch_impl: "_TorchMemorySaverImpl", torch_impl: "_TorchMemorySaverImpl",
socket_path: str,
device_index: int, device_index: int,
mode=None, mode=None,
): ):
self._torch_impl = torch_impl self._torch_impl = torch_impl
self._socket_path = socket_path
self._device_index = device_index self._device_index = device_index
self._requested_mode = mode self._requested_mode = mode
self._disabled = False self._disabled = False
self._imported_weights_bytes: int = 0 self._imported_weights_bytes: int = 0
self._allocator: Optional["GMSClientMemoryManager"] self._weights_allocator: Optional["GMSClientMemoryManager"]
self._mem_pool: Optional["MemPool"] self._kv_cache_allocator: "GMSClientMemoryManager"
self._mode: str self._mode: str
self._allocator, self._mem_pool, self._mode = self._init_allocator() (
self._weights_allocator,
self._kv_cache_allocator,
self._mode,
) = self._init_allocators()
logger.info( logger.info(
"[GMS] Initialized: weights=%s mode (device=%d, socket=%s)", "[GMS] Initialized weights=%s mode, kv_cache=RW (device=%d)",
self._mode.upper(), self._mode.upper(),
device_index, device_index,
socket_path,
) )
def _init_allocator( def _init_allocators(
self, self,
) -> tuple[Optional["GMSClientMemoryManager"], Optional["MemPool"], str]: ) -> tuple[Optional["GMSClientMemoryManager"], "GMSClientMemoryManager", str,]:
"""Create allocator with mode from config (default: RW_OR_RO).""" """Create allocator with mode from config (default: RW_OR_RO)."""
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
mode = self._requested_mode or RequestedLockType.RW_OR_RO mode = self._requested_mode or RequestedLockType.RW_OR_RO
allocator, mem_pool = get_or_create_gms_client_memory_manager( weights_allocator = get_or_create_gms_client_memory_manager(
self._socket_path, get_socket_path(self._device_index, "weights"),
self._device_index, self._device_index,
mode=mode, mode=mode,
tag="weights", tag="weights",
) )
granted_mode = allocator.granted_lock_type kv_cache_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "kv_cache"),
self._device_index,
mode=RequestedLockType.RW,
tag="kv_cache",
)
granted_mode = weights_allocator.granted_lock_type
if granted_mode == GrantedLockType.RW: if granted_mode == GrantedLockType.RW:
allocator.clear_all_handles()
actual_mode = "write" actual_mode = "write"
else: else:
actual_mode = "read" actual_mode = "read"
...@@ -97,11 +95,7 @@ class GMSMemorySaverImpl: ...@@ -97,11 +95,7 @@ class GMSMemorySaverImpl:
actual_mode.upper(), actual_mode.upper(),
self._device_index, self._device_index,
) )
return ( return weights_allocator, kv_cache_allocator, actual_mode
allocator,
mem_pool if granted_mode == GrantedLockType.RW else None,
actual_mode,
)
def _is_weights_tag(self, tag: Optional[str]) -> bool: def _is_weights_tag(self, tag: Optional[str]) -> bool:
return tag in ("weights", "model_weights") return tag in ("weights", "model_weights")
...@@ -110,25 +104,28 @@ class GMSMemorySaverImpl: ...@@ -110,25 +104,28 @@ class GMSMemorySaverImpl:
return self._mode return self._mode
def get_allocator(self) -> Optional["GMSClientMemoryManager"]: def get_allocator(self) -> Optional["GMSClientMemoryManager"]:
return self._allocator return self._weights_allocator
@contextmanager @contextmanager
def region(self, tag: str, enable_cpu_backup: bool): def region(self, tag: str, enable_cpu_backup: bool):
"""Mark allocation region with tag.""" """Mark allocation region with tag."""
if not self._is_weights_tag(tag): if self._is_weights_tag(tag):
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup): if self._mode == "read":
yield yield
return return
if self._mode == "read": target_device = torch.device("cuda", self._device_index)
with gms_use_mem_pool("weights", target_device):
yield yield
return return
if self._mem_pool is None: if tag == "kv_cache":
raise RuntimeError("GMS mempool is None in WRITE mode")
target_device = torch.device("cuda", self._device_index) target_device = torch.device("cuda", self._device_index)
with torch.cuda.use_mem_pool(self._mem_pool, device=target_device): with gms_use_mem_pool("kv_cache", target_device):
yield
return
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
yield yield
def pause(self, tag: Optional[str] = None) -> None: def pause(self, tag: Optional[str] = None) -> None:
...@@ -136,7 +133,9 @@ class GMSMemorySaverImpl: ...@@ -136,7 +133,9 @@ class GMSMemorySaverImpl:
return return
if tag is None or self._is_weights_tag(tag): if tag is None or self._is_weights_tag(tag):
self._pause_weights() self._pause_weights()
if tag is None or not self._is_weights_tag(tag): if tag is None or tag == "kv_cache":
self._pause_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.pause(tag=tag) self._torch_impl.pause(tag=tag)
def resume(self, tag: Optional[str] = None) -> None: def resume(self, tag: Optional[str] = None) -> None:
...@@ -144,39 +143,56 @@ class GMSMemorySaverImpl: ...@@ -144,39 +143,56 @@ class GMSMemorySaverImpl:
return return
if tag is None or self._is_weights_tag(tag): if tag is None or self._is_weights_tag(tag):
self._resume_weights() self._resume_weights()
if tag is None or not self._is_weights_tag(tag): if tag is None or tag == "kv_cache":
self._resume_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.resume(tag=tag) self._torch_impl.resume(tag=tag)
def _pause_weights(self) -> None: def _pause_weights(self) -> None:
if self._allocator is None: if self._weights_allocator is None:
return return
if self._allocator.is_unmapped: if self._weights_allocator.is_unmapped:
return return
logger.info("[GMS] Unmapping weights (VA-stable)") logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap_all_vas() self._weights_allocator.unmap_all_vas()
self._allocator.disconnect() self._weights_allocator.abort()
def _resume_weights(self) -> None: def _resume_weights(self) -> None:
if self._allocator is None: if self._weights_allocator is None:
return return
if not self._allocator.is_unmapped: if not self._weights_allocator.is_unmapped:
return return
logger.info("[GMS] Remapping weights (VA-stable)") logger.info("[GMS] Remapping weights (VA-stable)")
from gpu_memory_service.common.types import RequestedLockType self._weights_allocator.connect(RequestedLockType.RO)
self._weights_allocator.remap_all_vas()
self._allocator.connect(RequestedLockType.RO) def _pause_kv_cache(self) -> None:
self._allocator.remap_all_vas() if self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping KV cache")
self._kv_cache_allocator.unmap_all_vas()
self._kv_cache_allocator.abort()
def _resume_kv_cache(self) -> None:
if not self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Remapping KV cache")
self._kv_cache_allocator.connect(RequestedLockType.RW)
self._kv_cache_allocator.reallocate_all_handles(tag="kv_cache")
self._kv_cache_allocator.remap_all_vas()
def finalize_write_mode(self, model: torch.nn.Module) -> None: def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read.""" """Finalize write mode: register tensors, commit, and switch to read."""
if self._mode != "write": if self._mode != "write":
return return
if self._allocator is None: if self._weights_allocator is None:
raise RuntimeError("Allocator is None in WRITE mode") raise RuntimeError("Allocator is None in WRITE mode")
from gpu_memory_service.integrations.common.utils import finalize_gms_write from gpu_memory_service.integrations.common.utils import finalize_gms_write
self._imported_weights_bytes = finalize_gms_write(self._allocator, model) self._imported_weights_bytes = finalize_gms_write(
self._weights_allocator, model
)
self._mode = "read" self._mode = "read"
def set_imported_weights_bytes(self, bytes_count: int) -> None: def set_imported_weights_bytes(self, bytes_count: int) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment