".buildkite/vscode:/vscode.git/clone" did not exist on "36db0a35e45f32f7c37f6f1967dc8d6ff301d882"
Unverified Commit 30c6228b authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat: GPU Memory Service (#5286)


Signed-off-by: default avatarSchwinn Saereesitthipitak <17022745+galletas1712@users.noreply.github.com>
parent cde3b2a5
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service allocator singleton management.
Manages the singleton memory manager and PyTorch MemPool integration.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Optional, Tuple
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool
logger = logging.getLogger(__name__)
# Global singleton state
_gms_client_memory_manager: Optional["GMSClientMemoryManager"] = None
_mem_pool: Optional["MemPool"] = None
_pluggable_alloc: Optional[Any] = None
def get_or_create_gms_client_memory_manager(
socket_path: str,
device: int,
mode: RequestedLockType,
*,
tag: str = "weights",
timeout_ms: Optional[int] = None,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Get existing memory manager or create a new one.
Args:
socket_path: Unix socket path for the allocation server.
device: CUDA device index.
mode: RW for cold start, RO for import-only, RW_OR_RO for auto.
tag: Allocation tag for RW mode.
timeout_ms: Lock acquisition timeout (None = wait indefinitely).
Returns:
(gms_client_memory_manager, pool) - pool is None for RO mode.
"""
global _gms_client_memory_manager, _mem_pool
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
if _gms_client_memory_manager is not None:
return _get_existing(mode)
# Create new manager
gms_client_memory_manager = GMSClientMemoryManager(
socket_path, mode=mode, device=device, timeout_ms=timeout_ms
)
_gms_client_memory_manager = gms_client_memory_manager
if gms_client_memory_manager.mode == GrantedLockType.RW:
_mem_pool = _setup_mempool(gms_client_memory_manager, tag)
logger.info("[GMS] Created RW allocator (device=%d)", device)
return gms_client_memory_manager, _mem_pool
else:
logger.info("[GMS] Created RO allocator (device=%d)", device)
return gms_client_memory_manager, None
def _get_existing(
mode: RequestedLockType,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Return existing allocator if mode-compatible."""
current = _gms_client_memory_manager.mode
if mode == RequestedLockType.RW:
if current == GrantedLockType.RW:
return _gms_client_memory_manager, _mem_pool
raise RuntimeError(f"Cannot get RW allocator: existing is in {current} mode")
if mode == RequestedLockType.RO:
if current == GrantedLockType.RO:
return _gms_client_memory_manager, None
raise RuntimeError(
f"Cannot get RO allocator: existing is in {current} mode. "
"Call manager.switch_to_read() first."
)
# RW_OR_RO: return whatever exists
pool = _mem_pool if current == GrantedLockType.RW else None
return _gms_client_memory_manager, pool
def _setup_mempool(
gms_client_memory_manager: "GMSClientMemoryManager",
tag: str,
) -> "MemPool":
"""Set up PyTorch CUDAPluggableAllocator and MemPool."""
global _pluggable_alloc
from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem
from torch.cuda import CUDAPluggableAllocator
from torch.cuda.memory import MemPool
pluggable_alloc = CUDAPluggableAllocator(cumem.__file__, "my_malloc", "my_free")
pool = MemPool(allocator=pluggable_alloc.allocator())
_pluggable_alloc = pluggable_alloc
def malloc_cb(size: int, device: int, stream: int) -> int:
va = gms_client_memory_manager.allocate_and_map(int(size), tag=tag)
logger.debug("[GMS] malloc: va=0x%x size=%d", va, size)
return va
def free_cb(ptr: int, size: int, device: int, stream: int) -> None:
logger.debug("[GMS] free: va=0x%x size=%d", ptr, size)
gms_client_memory_manager.free_mapping(int(ptr))
cumem.init_module(malloc_cb, free_cb)
return pool
def get_gms_client_memory_manager() -> Optional["GMSClientMemoryManager"]:
"""Get the active GMS client memory manager, or None if not initialized."""
return _gms_client_memory_manager
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service C++ extensions for PyTorch integration.
These extensions are built at install time using setuptools.
- _allocator_ext: CUDAPluggableAllocator backend (my_malloc/my_free)
"""
# Built by setup.py build_ext --inplace
# Import will fail until extensions are built
try:
from gpu_memory_service.client.torch.extensions import _allocator_ext # noqa: F401
from gpu_memory_service.client.torch.extensions._allocator_ext import * # noqa: F401, F403
except ImportError:
_allocator_ext = None # type: ignore
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// Minimal CUDAPluggableAllocator shim for GPU Memory Service.
//
// This extension provides the my_malloc/my_free function pointers required by
// PyTorch's CUDAPluggableAllocator. All actual CUDA VMM operations are delegated
// to Python callbacks which use cuda.bindings.
//
// Note: The stream parameter is unused because CUDA VMM operations (cuMemMap,
// cuMemUnmap) are synchronous and globally visible - they don't have per-stream
// semantics like cudaMallocAsync. We keep the parameter to match PyTorch's
// CUDAPluggableAllocator interface signature.
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <cstdint>
static PyObject* g_malloc_callback = nullptr;
static PyObject* g_free_callback = nullptr;
extern "C" {
void*
my_malloc(ssize_t size, int device, void* stream)
{
if (!g_malloc_callback) {
return nullptr;
}
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* args = Py_BuildValue("(niK)", size, device, (unsigned long long)stream);
PyObject* result = PyObject_CallObject(g_malloc_callback, args);
Py_DECREF(args);
void* ptr = nullptr;
if (result && PyLong_Check(result)) {
ptr = (void*)PyLong_AsUnsignedLongLong(result);
}
Py_XDECREF(result);
if (PyErr_Occurred()) {
PyErr_Print();
}
PyGILState_Release(gstate);
return ptr;
}
void
my_free(void* ptr, ssize_t size, int device, void* stream)
{
if (!g_free_callback) {
return;
}
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* args = Py_BuildValue("(KniK)", (unsigned long long)ptr, size, device, (unsigned long long)stream);
PyObject* result = PyObject_CallObject(g_free_callback, args);
Py_DECREF(args);
Py_XDECREF(result);
if (PyErr_Occurred()) {
PyErr_Print();
}
PyGILState_Release(gstate);
}
static PyObject*
py_init_module(PyObject* self, PyObject* args)
{
PyObject* malloc_cb = nullptr;
PyObject* free_cb = nullptr;
if (!PyArg_ParseTuple(args, "OO", &malloc_cb, &free_cb)) {
return nullptr;
}
if (!PyCallable_Check(malloc_cb) || !PyCallable_Check(free_cb)) {
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
return nullptr;
}
Py_XINCREF(malloc_cb);
Py_XINCREF(free_cb);
Py_XDECREF(g_malloc_callback);
Py_XDECREF(g_free_callback);
g_malloc_callback = malloc_cb;
g_free_callback = free_cb;
Py_RETURN_NONE;
}
static PyMethodDef module_methods[] = {
{"init_module", py_init_module, METH_VARARGS, "Set malloc/free callbacks"}, {nullptr, nullptr, 0, nullptr}};
static struct PyModuleDef allocator_module = {
PyModuleDef_HEAD_INIT, "_allocator_ext", "CUDAPluggableAllocator shim for GPU Memory Service", -1, module_methods};
PyMODINIT_FUNC
PyInit__allocator_ext(void)
{
return PyModule_Create(&allocator_module);
}
} // extern "C"
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Module tensor operations for GPU Memory Service.
This module provides module-level tensor operations:
- Module tensor iteration
- Tensor registration (write path)
- Tensor materialization (read path)
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Iterator, Tuple
import torch
from gpu_memory_service.client.torch.tensor import GMSTensorSpec, TensorMetadata
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
logger = logging.getLogger(__name__)
# =============================================================================
# Module Tensor Iteration
# =============================================================================
def _iter_module_tensors(
module: torch.nn.Module,
prefix: str = "",
) -> Iterator[Tuple[str, torch.Tensor, str]]:
"""Iterate over all CUDA tensors in a module tree.
Yields (qualified_name, tensor, tensor_type) for:
- Parameters (tensor_type="parameter")
- Buffers (tensor_type="buffer")
- Other tensor attributes like _k_scale (tensor_type="tensor_attr")
Args:
module: The nn.Module to iterate.
prefix: Prefix for qualified names (used in recursion).
Yields:
(name, tensor, tensor_type) tuples for each CUDA tensor.
"""
# Parameters
for name, param in module._parameters.items():
if param is not None and param.is_cuda:
qualified = f"{prefix}{name}" if prefix else name
yield (qualified, param, "parameter")
# Buffers
for name, buf in module._buffers.items():
if buf is not None and buf.is_cuda:
qualified = f"{prefix}{name}" if prefix else name
yield (qualified, buf, "buffer")
# Other tensor attributes (not params/buffers/submodules)
skip = (
set(module._parameters.keys())
| set(module._buffers.keys())
| set(module._modules.keys())
)
for attr_name in dir(module):
if attr_name in skip or attr_name.startswith("__"):
continue
try:
attr_val = getattr(module, attr_name, None)
except Exception:
continue
if torch.is_tensor(attr_val) and attr_val.is_cuda:
qualified = f"{prefix}{attr_name}" if prefix else attr_name
yield (qualified, attr_val, "tensor_attr")
elif isinstance(attr_val, (list, tuple)) and attr_val:
if all(torch.is_tensor(x) and x.is_cuda for x in attr_val):
for i, x in enumerate(attr_val):
qualified = (
f"{prefix}{attr_name}.{i}" if prefix else f"{attr_name}.{i}"
)
yield (qualified, x, "tensor_attr")
# Recurse into submodules
for name, submodule in module._modules.items():
if submodule is not None:
subprefix = f"{prefix}{name}." if prefix else f"{name}."
yield from _iter_module_tensors(submodule, subprefix)
def _resolve_module_attr(
root: torch.nn.Module, qualified_name: str
) -> Tuple[torch.nn.Module, str]:
"""Resolve a dotted name to (parent_module, leaf_attr).
Handles ModuleList/Sequential (numeric indices) and ModuleDict (key access).
"""
parts = qualified_name.split(".")
mod = root
for p in parts[:-1]:
if hasattr(mod, p):
mod = getattr(mod, p)
elif hasattr(mod, "__getitem__"):
try:
mod = mod[int(p)] if p.isdigit() else mod[p]
except Exception:
raise AttributeError(f"Cannot resolve {p!r} in {qualified_name!r}")
else:
raise AttributeError(f"Cannot resolve {p!r} in {qualified_name!r}")
return mod, parts[-1]
# =============================================================================
# Public API - Registration and Materialization
# =============================================================================
def register_module_tensors(
gms_client_memory_manager: "GMSClientMemoryManager",
model: torch.nn.Module,
) -> None:
"""Register all model tensors into the GMS metadata store.
Args:
gms_client_memory_manager: GMS client memory manager in write mode.
model: PyTorch model to register.
"""
for name, tensor, tensor_type in _iter_module_tensors(model):
ptr = int(tensor.data_ptr())
# Find allocation containing this tensor
for va, mapping in gms_client_memory_manager.mappings.items():
if va <= ptr < va + mapping.aligned_size:
offset = ptr - va
meta = TensorMetadata.from_tensor(tensor, tensor_type)
gms_client_memory_manager.metadata_put(
key=name,
allocation_id=mapping.allocation_id,
offset_bytes=offset,
value=meta.to_bytes(),
)
break
else:
# No mapping matched - tensor pointer not in any GMS allocation
if tensor_type == "parameter":
# Parameters are model weights - must be in GMS allocations
raise RuntimeError(f"Tensor {name!r} not found in any GMS allocation")
# Buffers and tensor_attrs may be dynamically allocated (e.g., KV cache)
logger.debug(
"[GMS] Skipping %s %r - not in GMS allocations", tensor_type, name
)
def materialize_module_from_gms(
gms_client_memory_manager: "GMSClientMemoryManager",
model: torch.nn.Module,
*,
device_index: int,
) -> None:
"""Materialize model tensors from GMS.
Args:
gms_client_memory_manager: GMS client memory manager in read mode.
model: Model to populate with tensors.
device_index: CUDA device index.
"""
specs = GMSTensorSpec.load_all(gms_client_memory_manager)
for name, spec in specs.items():
tensor = spec.materialize(gms_client_memory_manager, device_index)
mod, attr = _resolve_module_attr(model, name)
tensor_type = spec.meta.tensor_type
# Tensor attrs and buffers: clone since they may be mutated
if tensor_type in ("tensor_attr", "buffer"):
if (
tensor_type == "buffer"
and hasattr(mod, "_buffers")
and attr in mod._buffers
):
mod._buffers[attr] = tensor.detach().clone()
else:
setattr(mod, attr, tensor.detach().clone())
continue
# Parameters: in-place update or replace meta tensors
if hasattr(mod, "_parameters") and attr in mod._parameters:
param = mod._parameters[attr]
if param is not None:
if param.shape != tensor.shape or param.dtype != tensor.dtype:
raise RuntimeError(
f"Shape/dtype mismatch for {name}: "
f"param={tuple(param.shape)}/{param.dtype}, "
f"gms={tuple(tensor.shape)}/{tensor.dtype}"
)
if param.is_meta or param.device != tensor.device:
mod._parameters[attr] = torch.nn.Parameter(
tensor, requires_grad=param.requires_grad
)
else:
param.data = tensor
continue
# Fallback: set as attribute
setattr(mod, attr, tensor)
# Check for meta tensors and warn
meta_tensors = [n for n, p in model.named_parameters() if p.is_meta]
meta_tensors += [n for n, b in model.named_buffers() if b.is_meta]
if meta_tensors:
logger.warning(
"[GMS] %d meta tensors not in metadata: %s",
len(meta_tensors),
meta_tensors[:10],
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tensor utilities for GPU Memory Service.
This module provides low-level tensor functionality:
- Tensor creation from CUDA pointers
- Tensor metadata serialization/deserialization
- GMS tensor spec for metadata store entries
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
import torch
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
# =============================================================================
# Tensor Creation from CUDA Pointer
# =============================================================================
def _tensor_from_pointer(
data_ptr: int,
shape: List[int],
stride: List[int],
dtype: torch.dtype,
device_index: int,
) -> torch.Tensor:
"""Create a torch.Tensor from a raw CUDA pointer without copying data.
Uses PyTorch's internal APIs to create a tensor that aliases existing
GPU memory. The tensor does NOT own the memory - the caller must ensure
the memory remains valid for the tensor's lifetime.
Args:
data_ptr: CUDA device pointer (virtual address) to the tensor data.
shape: Tensor dimensions.
stride: Tensor strides (in elements, not bytes).
dtype: Tensor data type.
device_index: CUDA device index where the memory resides.
Returns:
A tensor aliasing the specified GPU memory.
"""
device = torch.device("cuda", device_index)
# Calculate storage size in bytes based on stride (handles non-contiguous tensors)
# For non-contiguous tensors, the memory footprint is larger than numel * element_size
element_size = torch.tensor([], dtype=dtype).element_size()
if shape and stride:
if len(shape) != len(stride):
raise ValueError(
f"Shape and stride length mismatch: {len(shape)} vs {len(stride)}"
)
# Maximum offset = sum of stride[i] * (shape[i] - 1) for all dimensions
max_offset = sum(
s * (d - 1) for s, d in zip(stride, shape, strict=True) if d > 0
)
required_elements = max_offset + 1
else:
# Scalar tensor or empty tensor
required_elements = 1
storage_size_bytes = required_elements * element_size
# Create storage from raw pointer (does not take ownership)
storage = torch._C._construct_storage_from_data_pointer(
data_ptr, device, storage_size_bytes
)
# Create tensor from storage with metadata
metadata = {
"size": torch.Size(shape),
"stride": stride,
"storage_offset": 0,
"dtype": dtype,
}
return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, storage)
# =============================================================================
# Tensor Metadata - serialization format for metadata store
# =============================================================================
def _parse_dtype(dtype_str: str) -> torch.dtype:
"""Parse dtype string (e.g., 'torch.float16') to torch.dtype."""
s = str(dtype_str)
if s.startswith("torch."):
s = s.split(".", 1)[1]
dt = getattr(torch, s, None)
if not isinstance(dt, torch.dtype):
raise ValueError(f"Unknown dtype: {dtype_str!r}")
return dt
@dataclass(frozen=True)
class TensorMetadata:
"""Metadata for a tensor stored in the GMS metadata store."""
shape: Tuple[int, ...]
dtype: torch.dtype
stride: Tuple[int, ...]
tensor_type: str = "parameter" # "parameter", "buffer", or "tensor_attr"
@classmethod
def from_tensor(
cls, tensor: torch.Tensor, tensor_type: str = "parameter"
) -> "TensorMetadata":
"""Create TensorMetadata from an existing tensor."""
return cls(
shape=tuple(tensor.shape),
dtype=tensor.dtype,
stride=tuple(int(s) for s in tensor.stride()),
tensor_type=tensor_type,
)
@classmethod
def from_bytes(cls, value: bytes) -> "TensorMetadata":
"""Parse metadata from JSON bytes."""
obj = json.loads(value.decode("utf-8"))
shape = tuple(int(x) for x in obj["shape"])
dtype = _parse_dtype(obj["dtype"])
if "stride" in obj and obj["stride"] is not None:
stride = tuple(int(x) for x in obj["stride"])
else:
# Legacy format: compute contiguous stride
stride = []
acc = 1
for d in reversed(shape):
stride.append(acc)
acc *= d
stride = tuple(reversed(stride)) if stride else ()
return cls(
shape=shape,
dtype=dtype,
stride=stride,
tensor_type=obj.get("tensor_type", "parameter"),
)
def to_bytes(self) -> bytes:
"""Serialize to JSON bytes for metadata store."""
return json.dumps(
{
"shape": list(self.shape),
"dtype": str(self.dtype),
"stride": list(self.stride),
"tensor_type": self.tensor_type,
},
sort_keys=True,
).encode("utf-8")
# =============================================================================
# GMS Tensor Spec - metadata entry from store
# =============================================================================
@dataclass(frozen=True)
class GMSTensorSpec:
"""A tensor entry from the GMS metadata store."""
key: str
name: str
allocation_id: str
offset_bytes: int
meta: TensorMetadata
@classmethod
def load_all(
cls, gms_client_memory_manager: "GMSClientMemoryManager"
) -> Dict[str, "GMSTensorSpec"]:
"""Load all metadata entries.
Returns:
Mapping of tensor name -> GMSTensorSpec.
"""
specs: Dict[str, GMSTensorSpec] = {}
for key in gms_client_memory_manager.metadata_list():
got = gms_client_memory_manager.metadata_get(key)
if got is None:
raise RuntimeError(f"Metadata key disappeared: {key}")
allocation_id, offset_bytes, value = got
if key in specs:
raise RuntimeError(f"Duplicate tensor name: {key}")
specs[key] = cls(
key=key,
name=key,
allocation_id=str(allocation_id),
offset_bytes=int(offset_bytes),
meta=TensorMetadata.from_bytes(value),
)
return specs
def materialize(
self,
gms_client_memory_manager: "GMSClientMemoryManager",
device_index: int,
) -> torch.Tensor:
"""Create a tensor aliasing mapped CUDA memory."""
base_va = gms_client_memory_manager.import_allocation(self.allocation_id)
ptr = int(base_va) + int(self.offset_bytes)
return _tensor_from_pointer(
ptr,
list(self.meta.shape),
list(self.meta.stride),
self.meta.dtype,
device_index,
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA Virtual Memory Management (VMM) utility functions.
This module provides utility functions for CUDA driver API operations
used by both server (GMSServerMemoryManager) and client (GMSClientMemoryManager).
"""
from cuda.bindings import driver as cuda
def check_cuda_result(result: cuda.CUresult, name: str) -> None:
"""Check CUDA driver API result and raise on error.
Args:
result: CUDA driver API return code (CUresult enum)
name: Operation name for error message
Raises:
RuntimeError: If result is not CUDA_SUCCESS
"""
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
else:
err_msg = str(result)
raise RuntimeError(f"{name}: {err_msg}")
def ensure_cuda_initialized() -> None:
"""Ensure CUDA driver is initialized.
Raises:
RuntimeError: If cuInit fails
"""
(result,) = cuda.cuInit(0)
check_cuda_result(result, "cuInit")
def get_allocation_granularity(device: int) -> int:
"""Get VMM allocation granularity for a device.
Args:
device: CUDA device index
Returns:
Allocation granularity in bytes (typically 2 MiB)
"""
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, granularity = cuda.cuMemGetAllocationGranularity(
prop, cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
)
check_cuda_result(result, "cuMemGetAllocationGranularity")
return int(granularity)
def align_to_granularity(size: int, granularity: int) -> int:
"""Align size up to VMM granularity.
Args:
size: Size in bytes
granularity: Allocation granularity
Returns:
Aligned size
"""
return ((size + granularity - 1) // granularity) * granularity
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Message types for GPU Memory Service RPC protocol."""
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import msgspec
class RequestedLockType(str, Enum):
"""Lock type requested by client."""
RW = "rw"
RO = "ro"
RW_OR_RO = "rw_or_ro"
class GrantedLockType(str, Enum):
"""Lock type actually granted by server."""
RW = "rw"
RO = "ro"
class HandshakeRequest(msgspec.Struct, tag="handshake_request"):
lock_type: RequestedLockType
timeout_ms: Optional[int] = None
class HandshakeResponse(msgspec.Struct, tag="handshake_response"):
success: bool
committed: bool
granted_lock_type: Optional[GrantedLockType] = None
class CommitRequest(msgspec.Struct, tag="commit_request"):
pass
class CommitResponse(msgspec.Struct, tag="commit_response"):
success: bool
class GetLockStateRequest(msgspec.Struct, tag="get_lock_state_request"):
pass
class GetLockStateResponse(msgspec.Struct, tag="get_lock_state_response"):
state: str # "EMPTY", "RW", "COMMITTED", "RO"
has_rw_session: bool
ro_session_count: int
waiting_writers: int
committed: bool
is_ready: bool
class GetAllocationStateRequest(msgspec.Struct, tag="get_allocation_state_request"):
pass
class GetAllocationStateResponse(msgspec.Struct, tag="get_allocation_state_response"):
allocation_count: int
total_bytes: int
class AllocateRequest(msgspec.Struct, tag="allocate_request"):
size: int
tag: str = "default"
class AllocateResponse(msgspec.Struct, tag="allocate_response"):
allocation_id: str
size: int
aligned_size: int
class ExportRequest(msgspec.Struct, tag="export_request"):
allocation_id: str
class GetAllocationRequest(msgspec.Struct, tag="get_allocation_request"):
allocation_id: str
class GetAllocationResponse(msgspec.Struct, tag="get_allocation_response"):
allocation_id: str
size: int
aligned_size: int
tag: str
class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
tag: Optional[str] = None
class ListAllocationsResponse(msgspec.Struct, tag="list_allocations_response"):
allocations: List[Dict[str, Any]] = []
class FreeRequest(msgspec.Struct, tag="free_request"):
allocation_id: str
class FreeResponse(msgspec.Struct, tag="free_response"):
success: bool
class ClearAllRequest(msgspec.Struct, tag="clear_all_request"):
pass
class ClearAllResponse(msgspec.Struct, tag="clear_all_response"):
cleared_count: int
class ErrorResponse(msgspec.Struct, tag="error_response"):
error: str
code: int = 0
class MetadataPutRequest(msgspec.Struct, tag="metadata_put_request"):
key: str
allocation_id: str
offset_bytes: int
value: bytes
class MetadataPutResponse(msgspec.Struct, tag="metadata_put_response"):
success: bool
class MetadataGetRequest(msgspec.Struct, tag="metadata_get_request"):
key: str
class MetadataGetResponse(msgspec.Struct, tag="metadata_get_response"):
found: bool
allocation_id: Optional[str] = None
offset_bytes: Optional[int] = None
value: Optional[bytes] = None
class MetadataDeleteRequest(msgspec.Struct, tag="metadata_delete_request"):
key: str
class MetadataDeleteResponse(msgspec.Struct, tag="metadata_delete_response"):
deleted: bool
class MetadataListRequest(msgspec.Struct, tag="metadata_list_request"):
prefix: str = ""
class MetadataListResponse(msgspec.Struct, tag="metadata_list_response"):
keys: List[str] = []
class GetStateHashRequest(msgspec.Struct, tag="get_memory_layout_hash_request"):
pass
class GetStateHashResponse(msgspec.Struct, tag="get_memory_layout_hash_response"):
memory_layout_hash: str # Hash of allocations + metadata, empty if not committed
Message = Union[
HandshakeRequest,
HandshakeResponse,
CommitRequest,
CommitResponse,
GetLockStateRequest,
GetLockStateResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
AllocateRequest,
AllocateResponse,
ExportRequest,
GetAllocationRequest,
GetAllocationResponse,
ListAllocationsRequest,
ListAllocationsResponse,
FreeRequest,
FreeResponse,
ClearAllRequest,
ClearAllResponse,
ErrorResponse,
MetadataPutRequest,
MetadataPutResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataListRequest,
MetadataListResponse,
GetStateHashRequest,
GetStateHashResponse,
]
_encoder = msgspec.msgpack.Encoder()
_decoder = msgspec.msgpack.Decoder(Message)
def encode_message(msg: Message) -> bytes:
return _encoder.encode(msg)
def decode_message(data: bytes) -> Message:
return _decoder.decode(data)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Wire protocol for length-prefixed messages with optional FD passing."""
import asyncio
import os
import socket
import struct
from typing import Optional, Tuple
from .messages import Message, decode_message, encode_message
HEADER_SIZE = 4 # 4-byte big-endian length prefix
def _frame_message(msg: Message) -> bytes:
"""Encode and frame a message with length prefix."""
data = encode_message(msg)
return struct.pack("!I", len(data)) + data
def _try_extract_message(
recv_buffer: bytearray,
) -> Tuple[Optional[Message], bytearray, int]:
"""Try to extract a complete message from buffer.
Returns (message, remaining_buffer, bytes_needed).
"""
if len(recv_buffer) < HEADER_SIZE:
return None, recv_buffer, HEADER_SIZE - len(recv_buffer)
length = struct.unpack("!I", bytes(recv_buffer[:HEADER_SIZE]))[0]
total_needed = HEADER_SIZE + length
if len(recv_buffer) < total_needed:
return None, recv_buffer, total_needed - len(recv_buffer)
msg_data = bytes(recv_buffer[HEADER_SIZE:total_needed])
remaining = bytearray(recv_buffer[total_needed:])
return decode_message(msg_data), remaining, 0
# ==================== Async (for server) ====================
async def send_message(writer, msg: Message, fd: int = -1) -> None:
"""Send a length-prefixed message with optional FD via SCM_RIGHTS."""
frame = _frame_message(msg)
if fd >= 0:
transport_sock = writer.get_extra_info("socket")
if transport_sock is None:
raise RuntimeError("Cannot get socket from transport for FD passing")
def do_send_fd():
raw_fd = transport_sock.fileno()
dup_fd = os.dup(raw_fd)
try:
sock = socket.socket(fileno=dup_fd)
try:
sock.setblocking(True)
socket.send_fds(sock, [frame], [fd])
finally:
sock.detach()
except Exception:
os.close(dup_fd)
raise
await asyncio.get_running_loop().run_in_executor(None, do_send_fd)
else:
writer.write(frame)
await writer.drain()
async def recv_message(
reader, recv_buffer: Optional[bytearray] = None, raw_sock=None
) -> Tuple[Optional[Message], int, bytearray]:
"""Receive a length-prefixed message with optional FD.
Returns (message, fd, remaining_buffer). fd is -1 if none sent.
"""
if recv_buffer is None:
recv_buffer = bytearray()
# Check if complete message already in buffer
msg, remaining, _ = _try_extract_message(recv_buffer)
if msg is not None:
return msg, -1, remaining
loop = asyncio.get_running_loop()
fd = -1
# Receive more data
if raw_sock is not None:
raw_msg, fds, _flags, _addr = await loop.run_in_executor(
None, lambda: socket.recv_fds(raw_sock, 65536, 1)
)
if not raw_msg:
raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1
else:
chunk = await reader.read(65536)
if not chunk:
raise ConnectionResetError("Connection closed")
recv_buffer.extend(chunk)
# Try to extract message, read more if needed
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
if raw_sock is not None:
# Continue reading from raw socket to avoid buffer inconsistency
chunk = await loop.run_in_executor(
None, lambda n=bytes_needed: raw_sock.recv(n)
)
else:
chunk = await reader.read(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
# ==================== Sync (for client) ====================
def send_message_sync(sock, msg: Message, fd: int = -1) -> None:
"""Send a length-prefixed message with optional FD via SCM_RIGHTS."""
frame = _frame_message(msg)
if fd >= 0:
socket.send_fds(sock, [frame], [fd])
else:
sock.sendall(frame)
def recv_message_sync(
sock, recv_buffer: Optional[bytearray] = None
) -> Tuple[Optional[Message], int, bytearray]:
"""Receive a length-prefixed message with optional FD.
Returns (message, fd, remaining_buffer). fd is -1 if none sent.
"""
if recv_buffer is None:
recv_buffer = bytearray()
# Check if complete message already in buffer
msg, remaining, _ = _try_extract_message(recv_buffer)
if msg is not None:
return msg, -1, remaining
# Receive more data (with potential FD)
raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1)
if not raw_msg:
raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1
# Try to extract message, read more if needed
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
chunk = sock.recv(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared types for GPU Memory Service."""
from dataclasses import dataclass
from enum import Enum, auto
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
ClearAllRequest,
CommitRequest,
ExportRequest,
FreeRequest,
GetAllocationRequest,
GetAllocationStateRequest,
GetLockStateRequest,
GetStateHashRequest,
GrantedLockType,
ListAllocationsRequest,
MetadataDeleteRequest,
MetadataGetRequest,
MetadataListRequest,
MetadataPutRequest,
RequestedLockType,
)
# Re-export lock types for convenience
__all__ = [
"GrantedLockType",
"RequestedLockType",
"ServerState",
"StateEvent",
"StateSnapshot",
"derive_state",
"RW_REQUIRED",
"RO_ALLOWED",
"RW_ALLOWED",
]
class ServerState(str, Enum):
"""Server state - derived from actual connections."""
EMPTY = "EMPTY"
RW = "RW"
COMMITTED = "COMMITTED"
RO = "RO"
class StateEvent(Enum):
"""Events that trigger state transitions."""
RW_CONNECT = auto()
RW_COMMIT = auto()
RW_ABORT = auto()
RO_CONNECT = auto()
RO_DISCONNECT = auto()
@dataclass
class StateSnapshot:
"""Current server state snapshot."""
state: ServerState
has_rw: bool
ro_count: int
waiting_writers: int
committed: bool
@property
def is_ready(self) -> bool:
"""Ready = committed and no RW connection."""
return self.committed and not self.has_rw
def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState:
"""Derive server state from connection info."""
if has_rw:
return ServerState.RW
if ro_count > 0:
return ServerState.RO
if committed:
return ServerState.COMMITTED
return ServerState.EMPTY
# Permission sets: which message types require which connection mode
RW_REQUIRED: frozenset[type] = frozenset(
{
AllocateRequest,
FreeRequest,
ClearAllRequest,
MetadataPutRequest,
MetadataDeleteRequest,
CommitRequest,
}
)
RO_ALLOWED: frozenset[type] = frozenset(
{
ExportRequest,
GetAllocationRequest,
ListAllocationsRequest,
MetadataGetRequest,
MetadataListRequest,
GetLockStateRequest,
GetAllocationStateRequest,
GetStateHashRequest,
}
)
RW_ALLOWED: frozenset[type] = RW_REQUIRED | RO_ALLOWED
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "gpu-memory-service"
version = "0.8.0"
description = "GPU Memory Service for Dynamo - CUDA VMM-based GPU memory allocation and sharing"
readme = "README.md"
authors = [
{ name = "NVIDIA Inc.", email = "sw-dl-dynamo@nvidia.com" },
]
license = { text = "Apache-2.0" }
requires-python = ">=3.10"
dependencies = [
"msgspec>=0.18.0",
"uvloop>=0.21.0",
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Intended Audience :: Information Technology",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Operating System :: POSIX :: Linux",
]
keywords = ["llm", "genai", "inference", "nvidia", "gpu", "memory", "dynamo"]
[project.optional-dependencies]
test = [
"pytest>=8.3.4",
"pytest-asyncio",
]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service server components."""
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateSnapshot,
)
from gpu_memory_service.server.handler import MetadataEntry, RequestHandler
from gpu_memory_service.server.locking import Connection, GlobalLockFSM
from gpu_memory_service.server.memory_manager import (
AllocationInfo,
AllocationNotFoundError,
GMSServerMemoryManager,
)
from gpu_memory_service.server.rpc import GMSRPCServer
__all__ = [
"GMSRPCServer",
"GMSServerMemoryManager",
"AllocationInfo",
"AllocationNotFoundError",
"MetadataEntry",
"Connection",
"GrantedLockType",
"RequestedLockType",
"RequestHandler",
"ServerState",
"GlobalLockFSM",
"StateSnapshot",
]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Request handlers for GPU Memory Service."""
import hashlib
import logging
from dataclasses import dataclass
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
ClearAllResponse,
FreeRequest,
FreeResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateResponse,
GetLockStateResponse,
GetStateHashResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import derive_state
from .memory_manager import AllocationNotFoundError, GMSServerMemoryManager
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class MetadataEntry:
allocation_id: str
offset_bytes: int
value: bytes
class RequestHandler:
"""Handles allocation and metadata requests."""
def __init__(self, device: int = 0):
self._memory_manager = GMSServerMemoryManager(device)
self._metadata: dict[str, MetadataEntry] = {}
self._memory_layout_hash: str = (
"" # Hash of allocations + metadata, computed on commit
)
logger.info(f"RequestHandler initialized: device={device}")
@property
def granularity(self) -> int:
return self._memory_manager.granularity
def on_rw_abort(self) -> None:
"""Called when RW connection closes without commit."""
logger.warning("RW aborted; clearing allocations and metadata")
self._memory_manager.clear_all()
self._metadata.clear()
self._memory_layout_hash = ""
def on_commit(self) -> None:
"""Called when RW connection commits. Computes state hash."""
self._memory_layout_hash = self._compute_memory_layout_hash()
logger.info(f"Committed with state hash: {self._memory_layout_hash[:16]}...")
def _compute_memory_layout_hash(self) -> str:
"""Compute hash of current allocations + metadata."""
h = hashlib.sha256()
# Hash allocations (sorted by ID for determinism)
for info in sorted(
self._memory_manager.list_allocations(), key=lambda x: x.allocation_id
):
h.update(
f"{info.allocation_id}:{info.size}:{info.aligned_size}:{info.tag}".encode()
)
# Hash metadata (sorted by key for determinism)
for key in sorted(self._metadata.keys()):
entry = self._metadata[key]
h.update(f"{key}:{entry.allocation_id}:{entry.offset_bytes}:".encode())
h.update(entry.value)
return h.hexdigest()
def on_shutdown(self) -> None:
"""Called on server shutdown."""
if self._memory_manager.allocation_count > 0:
count = self._memory_manager.clear_all()
self._metadata.clear()
logger.info(f"Released {count} GPU allocations during shutdown")
# ==================== State Queries ====================
def handle_get_lock_state(
self,
has_rw: bool,
ro_count: int,
waiting_writers: int,
committed: bool,
) -> GetLockStateResponse:
"""Get lock/session state."""
state = derive_state(has_rw, ro_count, committed)
return GetLockStateResponse(
state=state.value,
has_rw_session=has_rw,
ro_session_count=ro_count,
waiting_writers=waiting_writers,
committed=committed,
is_ready=committed and not has_rw,
)
def handle_get_allocation_state(self) -> GetAllocationStateResponse:
"""Get allocation state."""
return GetAllocationStateResponse(
allocation_count=self._memory_manager.allocation_count,
total_bytes=self._memory_manager.total_bytes,
)
# ==================== Allocation Operations ====================
def handle_allocate(self, req: AllocateRequest) -> AllocateResponse:
"""Create physical memory allocation.
Requires RW connection (enforced by server).
"""
info = self._memory_manager.allocate(req.size, req.tag)
return AllocateResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
)
def handle_export(self, allocation_id: str) -> tuple[GetAllocationResponse, int]:
"""Export allocation as POSIX FD.
Returns (response, fd). Caller must close fd after sending.
"""
fd = self._memory_manager.export_fd(allocation_id)
info = self._memory_manager.get_allocation(allocation_id)
response = GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
)
return response, fd
def handle_get_allocation(self, req: GetAllocationRequest) -> GetAllocationResponse:
"""Get allocation info."""
try:
info = self._memory_manager.get_allocation(req.allocation_id)
return GetAllocationResponse(
allocation_id=info.allocation_id,
size=info.size,
aligned_size=info.aligned_size,
tag=info.tag,
)
except AllocationNotFoundError:
raise ValueError(f"Unknown allocation: {req.allocation_id}") from None
def handle_list_allocations(
self, req: ListAllocationsRequest
) -> ListAllocationsResponse:
"""List all allocations."""
allocations = self._memory_manager.list_allocations(req.tag)
result = [
{
"allocation_id": info.allocation_id,
"size": info.size,
"aligned_size": info.aligned_size,
"tag": info.tag,
}
for info in allocations
]
return ListAllocationsResponse(allocations=result)
def handle_free(self, req: FreeRequest) -> FreeResponse:
"""Free single allocation.
Requires RW connection (enforced by server).
"""
success = self._memory_manager.free(req.allocation_id)
return FreeResponse(success=success)
def handle_clear_all(self) -> ClearAllResponse:
"""Clear all allocations and metadata.
Requires RW connection (enforced by server).
"""
count = self._memory_manager.clear_all()
self._metadata.clear()
return ClearAllResponse(cleared_count=count)
# ==================== Metadata Operations ====================
def handle_metadata_put(self, req: MetadataPutRequest) -> MetadataPutResponse:
self._metadata[req.key] = MetadataEntry(
req.allocation_id, req.offset_bytes, req.value
)
return MetadataPutResponse(success=True)
def handle_metadata_get(self, req: MetadataGetRequest) -> MetadataGetResponse:
entry = self._metadata.get(req.key)
if entry is None:
return MetadataGetResponse(found=False)
return MetadataGetResponse(
found=True,
allocation_id=entry.allocation_id,
offset_bytes=entry.offset_bytes,
value=entry.value,
)
def handle_metadata_delete(
self, req: MetadataDeleteRequest
) -> MetadataDeleteResponse:
return MetadataDeleteResponse(
deleted=self._metadata.pop(req.key, None) is not None
)
def handle_metadata_list(self, req: MetadataListRequest) -> MetadataListResponse:
keys = (
[k for k in self._metadata if k.startswith(req.prefix)]
if req.prefix
else list(self._metadata)
)
return MetadataListResponse(keys=sorted(keys))
def handle_get_memory_layout_hash(self) -> GetStateHashResponse:
return GetStateHashResponse(memory_layout_hash=self._memory_layout_hash)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Connection and state machine for GPU Memory Service.
This module handles:
- Connection: Represents an active client connection
- GlobalLockFSM: Explicit state transitions with validated permissions
State Diagram:
EMPTY ──RW_CONNECT──► RW ──RW_COMMIT──► COMMITTED
▲ │ │
│ │ │
└───RW_ABORT─────────┘ │
COMMITTED ◄──RO_DISCONNECT (last)── RO ◄──RO_CONNECT
│ ▲
│ │
└──RO_CONNECT──────┘
└──RO_DISCONNECT───┘ (not last)
"""
from __future__ import annotations
import asyncio
import logging
import socket
from dataclasses import dataclass, field
from typing import Callable, Optional, Set
from gpu_memory_service.common.types import (
RO_ALLOWED,
RW_ALLOWED,
RW_REQUIRED,
GrantedLockType,
ServerState,
StateEvent,
)
logger = logging.getLogger(__name__)
# =============================================================================
# Connection
# =============================================================================
@dataclass(eq=False)
class Connection:
"""Represents an active connection.
The existence of Connection objects IS the state - we don't track
sessions separately. When a Connection is removed, the lock is released.
Note: eq=False disables auto-generated __eq__ so we can use default
object identity for equality and add __hash__ for use in sets.
"""
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
mode: GrantedLockType
session_id: str
recv_buffer: bytearray = field(default_factory=bytearray)
def __hash__(self) -> int:
"""Hash based on session_id (immutable identifier)."""
return hash(self.session_id)
@property
def raw_socket(self) -> socket.socket:
"""Get underlying socket for FD passing."""
return self.writer.get_extra_info("socket")
async def close(self) -> None:
"""Close the connection."""
self.writer.close()
try:
await self.writer.wait_closed()
except Exception:
pass
# =============================================================================
# State Machine
# =============================================================================
class InvalidTransition(Exception):
"""Raised when an invalid state transition is attempted."""
pass
class OperationNotAllowed(Exception):
"""Raised when an operation is not allowed in the current state/mode."""
pass
@dataclass(frozen=True)
class Transition:
"""A valid state transition.
Attributes:
from_states: Set of states this transition can originate from
event: The event that triggers this transition
to_state: The resulting state (or None if conditional)
condition: Optional condition function for conditional transitions
"""
from_states: frozenset[ServerState]
event: StateEvent
to_state: Optional[ServerState]
condition: Optional[str] = None # Name of condition method
# Transition table - the single source of truth for valid state transitions
TRANSITIONS: list[Transition] = [
# From EMPTY or COMMITTED: RW can connect
# Writer acquires exclusive lock
Transition(
from_states=frozenset({ServerState.EMPTY, ServerState.COMMITTED}),
event=StateEvent.RW_CONNECT,
to_state=ServerState.RW,
),
# From RW: commit publishes and transitions to COMMITTED
# Writer publishes and releases lock
Transition(
from_states=frozenset({ServerState.RW}),
event=StateEvent.RW_COMMIT,
to_state=ServerState.COMMITTED,
),
# From RW: abort (disconnect without commit) transitions to EMPTY
# Writer aborts, state invalidated
Transition(
from_states=frozenset({ServerState.RW}),
event=StateEvent.RW_ABORT,
to_state=ServerState.EMPTY,
),
# From COMMITTED or RO: RO can connect
# Reader acquires shared lock
Transition(
from_states=frozenset({ServerState.COMMITTED, ServerState.RO}),
event=StateEvent.RO_CONNECT,
to_state=ServerState.RO,
),
# From RO: reader disconnect (not last) stays in RO
# Reader leaves, others remain
Transition(
from_states=frozenset({ServerState.RO}),
event=StateEvent.RO_DISCONNECT,
to_state=ServerState.RO,
condition="has_remaining_readers",
),
# From RO: last reader disconnect transitions to COMMITTED
# Last reader leaves
Transition(
from_states=frozenset({ServerState.RO}),
event=StateEvent.RO_DISCONNECT,
to_state=ServerState.COMMITTED,
condition="is_last_reader",
),
]
@dataclass
class TransitionRecord:
"""Record of a state transition for debugging/auditing."""
from_state: ServerState
event: StateEvent
to_state: ServerState
session_id: Optional[str] = None
class GlobalLockFSM:
"""Explicit state machine for GPU Memory Service.
State is DERIVED from actual connection objects:
- _rw_conn: The active RW connection (or None)
- _ro_conns: Set of active RO connections
- _committed: Whether allocations have been committed
All state mutations happen through explicit transitions.
"""
def __init__(self, on_rw_abort: Optional[Callable[[], None]] = None):
"""Initialize the state machine.
Args:
on_rw_abort: Callback invoked when RW aborts (for cleanup)
"""
# Connection state - THIS IS THE SOURCE OF TRUTH
self._rw_conn: Optional[Connection] = None
self._ro_conns: Set[Connection] = set()
self._committed: bool = False
# Callback for RW abort cleanup
self._on_rw_abort = on_rw_abort
# Transition history for debugging
self._transition_log: list[TransitionRecord] = []
# ==================== State Properties ====================
@property
def state(self) -> ServerState:
"""Derive current state from connection objects."""
if self._rw_conn is not None:
return ServerState.RW
if len(self._ro_conns) > 0:
return ServerState.RO
if self._committed:
return ServerState.COMMITTED
return ServerState.EMPTY
@property
def rw_conn(self) -> Optional[Connection]:
"""The active RW connection, if any."""
return self._rw_conn
@property
def ro_conns(self) -> Set[Connection]:
"""Set of active RO connections."""
return self._ro_conns
@property
def ro_count(self) -> int:
"""Number of active RO connections."""
return len(self._ro_conns)
@property
def committed(self) -> bool:
"""Whether allocations have been committed."""
return self._committed
@property
def transition_log(self) -> list[TransitionRecord]:
"""History of state transitions."""
return self._transition_log
# ==================== Transition Conditions ====================
def _has_remaining_readers(self, conn: Connection) -> bool:
"""Check if there are readers remaining after removing conn."""
return len(self._ro_conns) > 1 or conn not in self._ro_conns
def _is_last_reader(self, conn: Connection) -> bool:
"""Check if conn is the last reader."""
return len(self._ro_conns) == 1 and conn in self._ro_conns
def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
"""Evaluate a named condition."""
if condition is None:
return True
if condition == "has_remaining_readers":
return self._has_remaining_readers(conn)
if condition == "is_last_reader":
return self._is_last_reader(conn)
raise ValueError(f"Unknown condition: {condition}")
# ==================== State Transitions ====================
def _find_transition(
self, from_state: ServerState, event: StateEvent, conn: Connection
) -> Optional[Transition]:
"""Find the applicable transition for the given event."""
for t in TRANSITIONS:
if from_state not in t.from_states:
continue
if t.event != event:
continue
if not self._check_condition(t.condition, conn):
continue
return t
return None
def _apply_event(self, event: StateEvent, conn: Connection) -> None:
"""Mutate internal state based on event."""
match event:
case StateEvent.RW_CONNECT:
self._rw_conn = conn
self._committed = False # Invalidate on RW connect
case StateEvent.RW_COMMIT:
self._committed = True
self._rw_conn = None
case StateEvent.RW_ABORT:
self._rw_conn = None
if self._on_rw_abort:
self._on_rw_abort()
case StateEvent.RO_CONNECT:
self._ro_conns.add(conn)
case StateEvent.RO_DISCONNECT:
self._ro_conns.discard(conn)
def transition(self, event: StateEvent, conn: Connection) -> ServerState:
"""Execute a state transition.
Args:
event: The triggering event
conn: The connection involved in the transition
Returns:
The new state after the transition
Raises:
InvalidTransition: If the transition is not valid from current state
"""
from_state = self.state
session_id = conn.session_id if conn else None
# Find valid transition
trans = self._find_transition(from_state, event, conn)
if trans is None:
raise InvalidTransition(
f"No transition for {event.name} from state {from_state.name} "
f"(session={session_id})"
)
# Apply the transition
self._apply_event(event, conn)
to_state = self.state
# Validate we ended up in expected state
if trans.to_state is not None and to_state != trans.to_state:
raise InvalidTransition(
f"Transition mismatch: expected {trans.to_state.name}, "
f"got {to_state.name}"
)
# Record transition
record = TransitionRecord(from_state, event, to_state, session_id)
self._transition_log.append(record)
logger.info(
f"State transition: {from_state.name} --{event.name}--> {to_state.name} "
f"(session={session_id})"
)
return to_state
# ==================== Operation Permissions ====================
def check_operation(self, msg_type: type, conn: Connection) -> None:
"""Check if a request type is allowed for the given connection.
Args:
msg_type: The request message type (e.g., AllocateRequest)
conn: The connection attempting the operation
Raises:
OperationNotAllowed: If the operation is not permitted
"""
current_state = self.state
# Determine allowed operations based on state
if current_state == ServerState.RW:
allowed = RW_ALLOWED
elif current_state == ServerState.RO:
allowed = RO_ALLOWED
else:
allowed = frozenset() # EMPTY and COMMITTED have no connections
if msg_type not in allowed:
raise OperationNotAllowed(
f"{msg_type.__name__} not allowed in state {current_state.name}"
)
# Check connection mode
if msg_type in RW_REQUIRED and conn.mode != GrantedLockType.RW:
raise OperationNotAllowed(
f"{msg_type.__name__} requires RW connection, "
f"but connection is {conn.mode.value}"
)
# ==================== Lock Acquisition Predicates ====================
def can_acquire_rw(self) -> bool:
"""Check if RW lock can be acquired now.
RW can only be acquired if:
- No current RW holder
- No RO holders
Note: This allows RW from COMMITTED state (for explicit reload).
For rw_or_ro mode, callers should also check `committed` to prefer RO.
"""
return self._rw_conn is None and len(self._ro_conns) == 0
def can_acquire_ro(self, waiting_writers: int) -> bool:
"""Check if RO lock can be acquired now.
Args:
waiting_writers: Number of writers waiting for the lock
"""
return self._rw_conn is None and waiting_writers == 0 and self._committed
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA VMM allocation manager - pure business logic, no threading/transport.
This module contains the GMSServerMemoryManager class which handles physical GPU memory
allocations via CUDA Virtual Memory Management (VMM) API. It creates shareable
memory without mapping it locally (no CUDA context needed on the server).
The GMSServerMemoryManager is NOT thread-safe. Callers must provide external
synchronization (e.g., via LockManager ensuring single-writer access).
"""
import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
from uuid import uuid4
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import (
align_to_granularity,
check_cuda_result,
ensure_cuda_initialized,
get_allocation_granularity,
)
logger = logging.getLogger(__name__)
@dataclass
class AllocationInfo:
"""Information about a single GPU memory allocation.
Attributes:
allocation_id: Unique identifier for this allocation
size: Requested size in bytes
aligned_size: Actual size after alignment to VMM granularity
handle: CUmemGenericAllocationHandle value
tag: User-provided tag for grouping allocations
created_at: Timestamp when allocation was created
"""
allocation_id: str
size: int
aligned_size: int
handle: int
tag: str
created_at: float
class AllocationNotFoundError(Exception):
"""Raised when an allocation_id doesn't exist."""
pass
class GMSServerMemoryManager:
"""GPU Memory Service server-side memory manager.
Manages CUDA VMM physical memory allocations. This class handles the core
memory operations using CUDA Virtual Memory Management (VMM) API. It creates
physical allocations that can be exported as POSIX file descriptors for
sharing with other processes.
Key design points:
- No VA mapping: The memory manager never maps memory to virtual addresses,
so it doesn't create a CUDA context. This allows it to survive GPU
driver failures.
- NOT thread-safe: Callers must provide external synchronization.
The GlobalLockFSM's RW/RO semantics ensure single-writer access.
"""
def __init__(self, device: int = 0):
self._device = device
self._allocations: Dict[str, AllocationInfo] = {}
ensure_cuda_initialized()
self._granularity = get_allocation_granularity(device)
logger.info(
f"GMSServerMemoryManager initialized: device={device}, granularity={self._granularity}"
)
@property
def device(self) -> int:
return self._device
@property
def granularity(self) -> int:
return self._granularity
@property
def allocation_count(self) -> int:
return len(self._allocations)
@property
def total_bytes(self) -> int:
return sum(info.aligned_size for info in self._allocations.values())
def _get(self, allocation_id: str) -> AllocationInfo:
info = self._allocations.get(allocation_id)
if info is None:
raise AllocationNotFoundError(f"Unknown allocation: {allocation_id}")
return info
def _release(self, info: AllocationInfo) -> None:
(result,) = cuda.cuMemRelease(info.handle)
if result != cuda.CUresult.CUDA_SUCCESS:
logger.warning(f"cuMemRelease failed for {info.allocation_id}: {result}")
def allocate(self, size: int, tag: str = "default") -> AllocationInfo:
"""Create a physical memory allocation (no VA mapping).
Uses cuMemCreate to allocate physical GPU memory that can be exported
as a file descriptor for sharing with other processes.
Args:
size: Requested size in bytes (will be aligned up to granularity)
tag: Tag for grouping allocations (e.g., "weights", "kv_cache")
Returns:
AllocationInfo with allocation_id, aligned_size, handle
Raises:
RuntimeError: If CUDA allocation fails
"""
aligned_size = align_to_granularity(size, self._granularity)
prop = cuda.CUmemAllocationProp()
prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = self._device
prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
result, handle = cuda.cuMemCreate(aligned_size, prop, 0)
check_cuda_result(result, "cuMemCreate")
info = AllocationInfo(
allocation_id=str(uuid4()),
size=size,
aligned_size=aligned_size,
handle=int(handle),
tag=tag,
created_at=time.time(),
)
self._allocations[info.allocation_id] = info
logger.debug(
f"Allocated {info.allocation_id}: size={size}, aligned={aligned_size}, tag={tag}"
)
return info
def export_fd(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD for SCM_RIGHTS transfer.
The returned file descriptor can be sent to another process via
Unix domain socket SCM_RIGHTS. The receiving process can then
import it using cuMemImportFromShareableHandle.
IMPORTANT: The caller MUST close the returned FD after sendmsg()
to avoid file descriptor leaks.
Args:
allocation_id: ID of allocation to export
Returns:
File descriptor (caller owns, must close after sending)
Raises:
AllocationNotFoundError: If allocation_id doesn't exist
RuntimeError: If CUDA export fails
"""
info = self._get(allocation_id)
result, fd = cuda.cuMemExportToShareableHandle(
info.handle,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
0,
)
check_cuda_result(result, "cuMemExportToShareableHandle")
return int(fd)
def free(self, allocation_id: str) -> bool:
"""Release physical memory for a single allocation.
Args:
allocation_id: ID of allocation to free
Returns:
True if allocation existed and was freed, False otherwise
"""
info = self._allocations.pop(allocation_id, None)
if info is None:
return False
self._release(info)
logger.debug(f"Freed allocation: {allocation_id}")
return True
def clear_all(self) -> int:
"""Release ALL allocations.
Used by loaders before loading a new model, or during cleanup
when a writer aborts without committing.
Returns:
Number of allocations cleared
"""
count = len(self._allocations)
for info in self._allocations.values():
self._release(info)
self._allocations.clear()
logger.info(f"Cleared {count} allocations")
return count
def get_allocation(self, allocation_id: str) -> AllocationInfo:
"""Get allocation info. Raises AllocationNotFoundError if not found."""
return self._get(allocation_id)
def list_allocations(self, tag: Optional[str] = None) -> List[AllocationInfo]:
"""List all allocations, optionally filtered by tag."""
if tag is None:
return list(self._allocations.values())
return [info for info in self._allocations.values() if info.tag == tag]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Async Allocation RPC Server - Single-threaded event loop with explicit state machine.
State transitions are explicit and validated by the GlobalLockFSM class.
Operations are checked against state/mode permissions before execution.
State Machine (see locking.py for full diagram):
EMPTY: No connections, not committed
RW: Writer connected (exclusive)
COMMITTED: No connections, committed (weights valid)
RO: Reader(s) connected (shared)
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import ClassVar, Optional
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
ClearAllRequest,
CommitRequest,
CommitResponse,
ErrorResponse,
ExportRequest,
FreeRequest,
GetAllocationRequest,
GetAllocationStateRequest,
GetLockStateRequest,
GetStateHashRequest,
HandshakeRequest,
HandshakeResponse,
ListAllocationsRequest,
MetadataDeleteRequest,
MetadataGetRequest,
MetadataListRequest,
MetadataPutRequest,
)
from gpu_memory_service.common.protocol.wire import recv_message, send_message
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from .handler import RequestHandler
from .locking import Connection, GlobalLockFSM
logger = logging.getLogger(__name__)
class GMSRPCServer:
"""GPU Memory Service RPC Server.
Async single-threaded server using GlobalLockFSM for explicit state transitions
and operation validation. All state mutations happen through the state machine's
transition() method.
"""
def __init__(self, socket_path: str, device: int = 0):
self.socket_path = socket_path
self.device = device
# Request handler (business logic)
self._handler = RequestHandler(device)
# State machine - handles all state transitions and permission checks
self._sm = GlobalLockFSM(on_rw_abort=self._handler.on_rw_abort)
self._waiting_writers: int = 0
# Async waiting for lock acquisition
self._condition = asyncio.Condition()
self._shutdown = False
# Session ID generation
self._next_session_id: int = 0
# Server state
self._server: Optional[asyncio.Server] = None
self._running: bool = False
logger.info(f"GMSRPCServer initialized: device={device}")
# ==================== State Properties ====================
@property
def state(self) -> ServerState:
"""Current server state (delegated to state machine)."""
return self._sm.state
@property
def granularity(self) -> int:
return self._handler.granularity
def is_ready(self) -> bool:
"""Ready = committed and no RW connection."""
return self._sm.committed and self._sm.rw_conn is None
@property
def running(self) -> bool:
"""Whether the server is running."""
return self._running
def _generate_session_id(self) -> str:
self._next_session_id += 1
return f"session_{self._next_session_id}"
# ==================== Connection Lifecycle ====================
async def _handle_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""Handle a connection from accept to close."""
session_id = self._generate_session_id()
conn: Optional[Connection] = None
try:
conn = await self._do_handshake(reader, writer, session_id)
if conn is None:
return
await self._request_loop(conn)
except ConnectionResetError:
logger.debug(f"Connection reset: {session_id}")
except asyncio.CancelledError:
raise
except Exception:
logger.exception(f"Connection error: {session_id}")
finally:
await self._cleanup_connection(conn)
async def _do_handshake(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
session_id: str,
) -> Optional[Connection]:
"""Perform handshake and acquire lock via state machine transition."""
try:
# Server never receives FDs from clients, so no need for raw_sock
msg, _, recv_buffer = await recv_message(reader, bytearray())
except Exception:
logger.exception("Handshake recv error")
return None
if not isinstance(msg, HandshakeRequest):
await send_message(writer, ErrorResponse(error="Expected HandshakeRequest"))
writer.close()
return None
# Acquire lock (blocks until available or timeout)
# Returns the actual granted mode (may differ from requested for rw_or_ro)
granted_mode = await self._acquire_lock(msg.lock_type, msg.timeout_ms)
if granted_mode is None:
await send_message(
writer, HandshakeResponse(success=False, committed=self._sm.committed)
)
writer.close()
return None
conn = Connection(reader, writer, granted_mode, session_id, recv_buffer)
# State transition: connect
event = (
StateEvent.RW_CONNECT
if granted_mode == GrantedLockType.RW
else StateEvent.RO_CONNECT
)
self._sm.transition(event, conn)
await send_message(
writer,
HandshakeResponse(
success=True,
committed=self._sm.committed,
granted_lock_type=granted_mode,
),
)
return conn
async def _acquire_lock(
self, mode: RequestedLockType, timeout_ms: Optional[int]
) -> Optional[GrantedLockType]:
"""Wait until lock can be acquired (uses state machine predicates).
Returns the granted lock type, or None if failed/timeout.
For rw_or_ro mode, returns RW if available immediately, else waits for RO.
"""
timeout = timeout_ms / 1000 if timeout_ms is not None else None
if mode == RequestedLockType.RW:
self._waiting_writers += 1
try:
async with self._condition:
try:
await asyncio.wait_for(
self._condition.wait_for(
lambda: self._shutdown or self._sm.can_acquire_rw()
),
timeout=timeout,
)
return None if self._shutdown else GrantedLockType.RW
except asyncio.TimeoutError:
return None
finally:
self._waiting_writers -= 1
elif mode == RequestedLockType.RO:
async with self._condition:
try:
await asyncio.wait_for(
self._condition.wait_for(
lambda: self._shutdown
or self._sm.can_acquire_ro(self._waiting_writers)
),
timeout=timeout,
)
return None if self._shutdown else GrantedLockType.RO
except asyncio.TimeoutError:
return None
elif mode == RequestedLockType.RW_OR_RO:
# Auto mode: try RW if available immediately AND no committed weights,
# otherwise wait for RO (to import existing weights)
async with self._condition:
# Check if RW is available AND no committed weights exist
# If weights are already committed, prefer RO to import them
if self._sm.can_acquire_rw() and not self._sm.committed:
return GrantedLockType.RW
# Either RW not available OR weights already committed - wait for RO
if self._sm.committed:
logger.info(
"RW_OR_RO: Weights already committed, preferring RO to import"
)
else:
logger.info(
"RW_OR_RO: RW not available (another writer active), "
"falling back to RO"
)
try:
await asyncio.wait_for(
self._condition.wait_for(
lambda: self._shutdown
or self._sm.can_acquire_ro(self._waiting_writers)
),
timeout=timeout,
)
return None if self._shutdown else GrantedLockType.RO
except asyncio.TimeoutError:
return None
return None
async def _cleanup_connection(self, conn: Optional[Connection]) -> None:
"""Clean up after connection closes via state machine transition."""
if conn is None:
return
# State transition: disconnect
if conn.mode == GrantedLockType.RW:
if self._sm.rw_conn is conn and not self._sm.committed:
# RW abort - state machine callback handles cleanup
self._sm.transition(StateEvent.RW_ABORT, conn)
elif self._sm.rw_conn is conn:
# Already committed, no transition needed (commit already did it)
pass
else:
if conn in self._sm.ro_conns:
self._sm.transition(StateEvent.RO_DISCONNECT, conn)
await conn.close()
async with self._condition:
self._condition.notify_all()
# ==================== Request Handling ====================
async def _request_loop(self, conn: Connection) -> None:
"""Process requests until close or commit."""
while self._running:
try:
# Server never receives FDs from clients, so no need for raw_socket
msg, _, conn.recv_buffer = await recv_message(
conn.reader, conn.recv_buffer
)
except ConnectionResetError:
return
except asyncio.CancelledError:
raise
except Exception:
logger.exception("Recv error")
return
if msg is None:
continue
try:
response, fd, should_close = await self._dispatch(conn, msg)
if response is not None:
try:
await send_message(conn.writer, response, fd)
finally:
if fd >= 0:
os.close(fd)
if should_close:
return
except Exception as e:
logger.exception("Request error")
await send_message(conn.writer, ErrorResponse(error=str(e)))
# Dispatch table: message type -> handler method name
# Handlers take (msg) and return response. Special cases handled separately.
_HANDLERS: ClassVar[dict[type, str]] = {
AllocateRequest: "handle_allocate",
GetAllocationRequest: "handle_get_allocation",
ListAllocationsRequest: "handle_list_allocations",
FreeRequest: "handle_free",
MetadataPutRequest: "handle_metadata_put",
MetadataGetRequest: "handle_metadata_get",
MetadataDeleteRequest: "handle_metadata_delete",
MetadataListRequest: "handle_metadata_list",
}
async def _dispatch(self, conn: Connection, msg) -> tuple[object, int, bool]:
"""Dispatch request to handler. Returns (response, fd, should_close)."""
msg_type = type(msg)
self._sm.check_operation(msg_type, conn)
# Special cases
if msg_type is CommitRequest:
return await self._handle_commit(conn)
if msg_type is GetLockStateRequest:
return (
self._handler.handle_get_lock_state(
self._sm.rw_conn is not None,
self._sm.ro_count,
self._waiting_writers,
self._sm.committed,
),
-1,
False,
)
if msg_type is GetAllocationStateRequest:
return self._handler.handle_get_allocation_state(), -1, False
if msg_type is ExportRequest:
response, fd = self._handler.handle_export(msg.allocation_id)
return response, fd, False
if msg_type is ClearAllRequest:
return self._handler.handle_clear_all(), -1, False
if msg_type is GetStateHashRequest:
return self._handler.handle_get_memory_layout_hash(), -1, False
# Standard dispatch: handler takes msg, returns response
handler_name = self._HANDLERS.get(msg_type)
if handler_name:
handler = getattr(self._handler, handler_name)
return handler(msg), -1, False
raise ValueError(f"Unknown request: {msg_type.__name__}")
async def _handle_commit(self, conn: Connection) -> tuple[object, int, bool]:
"""Handle commit via state machine transition - atomic with disconnect."""
# Compute state hash before transitioning
self._handler.on_commit()
# State transition: commit
self._sm.transition(StateEvent.RW_COMMIT, conn)
await send_message(conn.writer, CommitResponse(success=True))
await conn.close()
async with self._condition:
self._condition.notify_all()
return None, -1, True
# ==================== Server Lifecycle ====================
async def start(self) -> None:
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
self._server = await asyncio.start_unix_server(
self._handle_connection, path=self.socket_path
)
self._running = True
logger.info(f"Server started: {self.socket_path}")
async def stop(self) -> None:
self._running = False
self._shutdown = True
async with self._condition:
self._condition.notify_all()
if self._server:
self._server.close()
await self._server.wait_closed()
self._server = None
# Close connections (bypassing state machine - this is shutdown)
if self._sm.rw_conn:
await self._sm.rw_conn.close()
for conn in list(self._sm.ro_conns):
await conn.close()
self._handler.on_shutdown()
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
logger.info("Server stopped")
async def serve_forever(self) -> None:
await self.start()
try:
while self._running:
await asyncio.sleep(1)
finally:
await self.stop()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Build script for GPU Memory Service with C++ extensions.
This setup.py builds the C++ extensions as part of pip install.
The _allocator_ext extension only requires Python headers (no CUDA or PyTorch needed).
Following the torch_memory_saver pattern of using pure setuptools for extension building.
"""
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
class BuildExtension(build_ext):
"""Custom build extension for C++ modules."""
def build_extensions(self):
import os
# Use CXX environment variable if set, otherwise default to g++
cxx = os.environ.get("CXX", "g++")
self.compiler.set_executable("compiler_so", cxx)
self.compiler.set_executable("compiler_cxx", cxx)
self.compiler.set_executable("linker_so", f"{cxx} -shared")
build_ext.build_extensions(self)
def _create_ext_modules():
"""Create extension modules for gpu_memory_service."""
# Common compile arguments
extra_compile_args = ["-std=c++17", "-O3", "-fPIC"]
# _allocator_ext: CUDAPluggableAllocator shim using only Python C API
# No CUDA or PyTorch dependency - just provides my_malloc/my_free that call Python callbacks
return [
Extension(
name="gpu_memory_service.client.torch.extensions._allocator_ext",
sources=["client/torch/extensions/allocator.cpp"],
extra_compile_args=extra_compile_args,
)
]
setup(
name="gpu-memory-service",
version="0.8.0",
description="GPU Memory Service for Dynamo - CUDA VMM-based GPU memory allocation and sharing",
author="NVIDIA Inc.",
author_email="sw-dl-dynamo@nvidia.com",
license="Apache-2.0",
python_requires=">=3.10",
install_requires=[
"msgpack>=1.0",
"uvloop>=0.21.0",
],
extras_require={
"test": [
"pytest>=8.3.4",
"pytest-asyncio",
],
},
# Package directory mapping: the current directory IS the gpu_memory_service package
packages=[
"gpu_memory_service",
"gpu_memory_service.common",
"gpu_memory_service.common.protocol",
"gpu_memory_service.server",
"gpu_memory_service.client",
"gpu_memory_service.client.torch",
"gpu_memory_service.client.torch.extensions",
],
package_dir={
"gpu_memory_service": ".",
"gpu_memory_service.common": "common",
"gpu_memory_service.common.protocol": "common/protocol",
"gpu_memory_service.server": "server",
"gpu_memory_service.client": "client",
"gpu_memory_service.client.torch": "client/torch",
"gpu_memory_service.client.torch.extensions": "client/torch/extensions",
},
package_data={
"gpu_memory_service.client.torch.extensions": ["*.cpp"],
},
ext_modules=_create_ext_modules(),
cmdclass={"build_ext": BuildExtension},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment