Unverified Commit 91a8c6af authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix: simplify GMS layout state and harden GPU-backed flows (#7006)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
Co-authored-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: default avatarhhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent dd7ceb4a
This diff is collapsed.
......@@ -32,7 +32,9 @@ from gpu_memory_service.client.memory_manager import (
# PyTorch integration (GMS client memory manager)
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
__all__ = [
......@@ -42,4 +44,6 @@ __all__ = [
# GMS client memory manager
"get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
]
......@@ -6,6 +6,7 @@
import argparse
import logging
from dataclasses import dataclass
from typing import Optional
from gpu_memory_service.common.utils import get_socket_path
......@@ -17,7 +18,10 @@ class Config:
"""Configuration for GPU Memory Service server."""
device: int
tag: str
socket_path: str
alloc_retry_interval: float
alloc_retry_timeout: Optional[float]
verbose: bool
......@@ -33,6 +37,12 @@ def parse_args() -> Config:
required=True,
help="CUDA device ID to manage memory for.",
)
parser.add_argument(
"--tag",
type=str,
default="weights",
help="Logical GMS tag for this server (default: weights).",
)
parser.add_argument(
"--socket-path",
type=str,
......@@ -45,14 +55,33 @@ def parse_args() -> Config:
action="store_true",
help="Enable verbose logging.",
)
parser.add_argument(
"--alloc-retry-interval",
type=float,
default=0.5,
help="Seconds to sleep between allocation retries on CUDA OOM (default: 0.5).",
)
parser.add_argument(
"--alloc-retry-timeout",
type=float,
default=None,
help="Optional max seconds to wait for allocation retries before failing (default: wait indefinitely).",
)
args = parser.parse_args()
# Use UUID-based socket path by default (stable across CUDA_VISIBLE_DEVICES)
socket_path = args.socket_path or get_socket_path(args.device)
socket_path = args.socket_path or get_socket_path(args.device, args.tag)
if args.alloc_retry_interval <= 0:
parser.error("--alloc-retry-interval must be > 0")
if args.alloc_retry_timeout is not None and args.alloc_retry_timeout <= 0:
parser.error("--alloc-retry-timeout must be > 0 when set")
return Config(
device=args.device,
tag=args.tag,
socket_path=socket_path,
alloc_retry_interval=args.alloc_retry_interval,
alloc_retry_timeout=args.alloc_retry_timeout,
verbose=args.verbose,
)
......@@ -13,7 +13,6 @@ Usage:
import asyncio
import logging
import signal
import uvloop
from gpu_memory_service.server import GMSRPCServer
......@@ -37,33 +36,28 @@ async def worker() -> None:
logging.getLogger("gpu_memory_service").setLevel(logging.DEBUG)
logger.info(f"Starting GPU Memory Service Server for device {config.device}")
logger.info("GMS tag: %s", config.tag)
logger.info(f"Socket path: {config.socket_path}")
server = GMSRPCServer(config.socket_path, device=config.device)
# Set up shutdown handling
shutdown_event = asyncio.Event()
def signal_handler():
logger.info("Received shutdown signal")
shutdown_event.set()
loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
await server.start()
logger.info(
"Allocation retry config: interval=%ss timeout=%s",
config.alloc_retry_interval,
(
f"{config.alloc_retry_timeout}s"
if config.alloc_retry_timeout is not None
else "none"
),
)
server = GMSRPCServer(
config.socket_path,
device=config.device,
allocation_retry_interval=config.alloc_retry_interval,
allocation_retry_timeout=config.alloc_retry_timeout,
)
logger.info("GPU Memory Service Server ready, waiting for connections...")
logger.info(f"Clients can connect via socket: {config.socket_path}")
# Wait for shutdown signal
try:
await shutdown_event.wait()
finally:
logger.info("Shutting down GPU Memory Service Server...")
await server.stop()
logger.info("GPU Memory Service Server shutdown complete")
await server.serve()
def main() -> None:
......
......@@ -7,7 +7,6 @@ This module provides the client-side components for interacting with the
GPU Memory Service:
- GMSClientMemoryManager: Manages local VA mappings of remote GPU memory
- GMSRPCClient: Low-level RPC client (pure Python, no PyTorch dependency)
For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch.
"""
......@@ -16,10 +15,8 @@ from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
StaleMemoryLayoutError,
)
from gpu_memory_service.client.rpc import GMSRPCClient
__all__ = [
"GMSClientMemoryManager",
"StaleMemoryLayoutError",
"GMSRPCClient",
]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Client-side CUDA VMM utilities.
These functions wrap CUDA driver API calls used by the client memory manager
for importing, mapping, and unmapping GPU memory.
"""
from __future__ import annotations
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import check_cuda_result
from gpu_memory_service.common.types import GrantedLockType
def import_handle_from_fd(fd: int) -> int:
    """Convert a received POSIX file descriptor into a CUDA memory handle.

    The descriptor is closed unconditionally before this function returns,
    whether the import succeeded or raised. The imported handle holds its
    own reference to the physical allocation; an FD left open would keep a
    DMA-buf reference alive and stop cuMemRelease from freeing GPU memory.

    Args:
        fd: POSIX file descriptor received via SCM_RIGHTS.

    Returns:
        The imported CUDA memory handle as a plain int.
    """
    try:
        fd_handle_type = (
            cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
        )
        status, imported = cuda.cuMemImportFromShareableHandle(fd, fd_handle_type)
        check_cuda_result(status, "cuMemImportFromShareableHandle")
        return int(imported)
    finally:
        # Runs on both the success and the error path.
        os.close(fd)
def reserve_va(size: int, granularity: int) -> int:
    """Reserve a range of virtual address space via cuMemAddressReserve.

    ``granularity`` is forwarded as the second argument of the driver call;
    the two trailing arguments are passed as zero.

    Args:
        size: Size in bytes (should be aligned to granularity).
        granularity: VMM allocation granularity.

    Returns:
        The reserved virtual address as a plain int.
    """
    status, base_address = cuda.cuMemAddressReserve(size, granularity, 0, 0)
    check_cuda_result(status, "cuMemAddressReserve")
    return int(base_address)
def free_va(va: int, size: int) -> None:
    """Give back a virtual address reservation via cuMemAddressFree.

    Args:
        va: Virtual address to free.
        size: Size of the reservation.
    """
    (status,) = cuda.cuMemAddressFree(va, size)
    check_cuda_result(status, "cuMemAddressFree")
def map_to_va(va: int, size: int, handle: int) -> None:
    """Back a reserved virtual address range with a CUDA memory handle.

    Calls cuMemMap with the third and fifth arguments fixed at zero.

    Args:
        va: Virtual address (must be reserved).
        size: Size of the mapping.
        handle: CUDA memory handle.
    """
    (status,) = cuda.cuMemMap(va, size, 0, handle, 0)
    check_cuda_result(status, "cuMemMap")
def set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
    """Apply access permissions to a mapped region for one device.

    Args:
        va: Virtual address.
        size: Size of the region.
        device: CUDA device index.
        access: Access mode. ``GrantedLockType.RO`` selects read-only
            protection; any other value selects read-write.
    """
    descriptor = cuda.CUmemAccessDesc()
    descriptor.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    descriptor.location.id = device
    if access == GrantedLockType.RO:
        descriptor.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
    else:
        descriptor.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

    # A single descriptor is passed, hence the count of 1.
    (status,) = cuda.cuMemSetAccess(va, size, [descriptor], 1)
    check_cuda_result(status, "cuMemSetAccess")
def unmap(va: int, size: int) -> None:
    """Remove the mapping from a virtual address region via cuMemUnmap.

    Args:
        va: Virtual address to unmap.
        size: Size of the mapping.
    """
    (status,) = cuda.cuMemUnmap(va, size)
    check_cuda_result(status, "cuMemUnmap")
def release_handle(handle: int) -> None:
    """Drop a CUDA memory handle via cuMemRelease.

    Args:
        handle: CUDA memory handle to release.
    """
    (status,) = cuda.cuMemRelease(handle)
    check_cuda_result(status, "cuMemRelease")
def validate_pointer(va: int) -> bool:
    """Check that a mapped VA is accessible.

    Queries the DEVICE_POINTER attribute of ``va``. A failing query is
    reported through a warning log entry rather than an exception.

    Args:
        va: Virtual address to probe.

    Returns:
        True if the attribute query succeeded, False otherwise.
    """
    status, _device_ptr = cuda.cuPointerGetAttribute(
        cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va
    )
    if status == cuda.CUresult.CUDA_SUCCESS:
        return True

    # Try to attach the driver's text for the error; fall back to "".
    lookup_status, raw_text = cuda.cuGetErrorString(status)
    description = ""
    if lookup_status == cuda.CUresult.CUDA_SUCCESS and raw_text:
        if isinstance(raw_text, bytes):
            description = raw_text.decode()
        else:
            description = str(raw_text)

    import logging

    logging.getLogger(__name__).warning(
        "cuPointerGetAttribute failed for VA 0x%x: %s (%s)",
        va,
        status,
        description,
    )
    return False
def synchronize() -> None:
    """Wait for the current CUDA context to finish its outstanding work.

    Blocks until all preceding commands in the current context have completed.
    """
    (status,) = cuda.cuCtxSynchronize()
    check_cuda_result(status, "cuCtxSynchronize")
def set_current_device(device: int) -> None:
    """Make ``device`` current by activating its primary context.

    Retains the device's primary context and then installs it as the
    current context.

    Args:
        device: CUDA device index.
    """
    # NOTE(review): the retained primary context is not released in this
    # function; confirm whether a matching release is expected elsewhere.
    retain_status, primary_ctx = cuda.cuDevicePrimaryCtxRetain(device)
    check_cuda_result(retain_status, "cuDevicePrimaryCtxRetain")

    (set_status,) = cuda.cuCtxSetCurrent(primary_ctx)
    check_cuda_result(set_status, "cuCtxSetCurrent")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service RPC Client.
"""Internal GPU Memory Service transport.
Low-level RPC client stub. The client provides a simple interface for acquiring
locks and performing allocation operations. The socket connection IS the lock.
This module has NO PyTorch dependency.
Usage:
# Writer (acquires RW lock in constructor)
with GMSRPCClient(socket_path, lock_type=RequestedLockType.RW) as client:
alloc_id, aligned_size = client.allocate(size=1024*1024)
fd = client.export(alloc_id)
# ... write weights using fd ...
client.commit()
# Lock released on exit
# Reader (acquires RO lock in constructor)
client = GMSRPCClient(socket_path, lock_type=RequestedLockType.RO)
if client.committed: # Check if weights are valid
allocations = client.list_allocations()
for alloc in allocations:
fd = client.export(alloc["allocation_id"])
# ... import and map fd ...
# Keep connection open during inference!
# client.close() only when done with inference
This module only owns Unix socket transport and typed request/response exchange.
Session semantics live in `gpu_memory_service.client.session`.
"""
from __future__ import annotations
import logging
import os
import socket
from typing import Dict, List, Optional, Tuple, Type, TypeVar
from typing import Optional, Tuple, Type, TypeVar
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
ClearAllRequest,
ClearAllResponse,
CommitRequest,
CommitResponse,
ErrorResponse,
ExportRequest,
FreeRequest,
FreeResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetLockStateRequest,
GetLockStateResponse,
GetStateHashRequest,
GetStateHashResponse,
HandshakeRequest,
HandshakeResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync
from gpu_memory_service.common.types import (
RW_REQUIRED,
GrantedLockType,
RequestedLockType,
)
from gpu_memory_service.common.types import RequestedLockType
T = TypeVar("T")
logger = logging.getLogger(__name__)
class GMSRPCClient:
"""GPU Memory Service RPC Client.
CRITICAL: Socket connection IS the lock.
- Constructor blocks until lock is acquired
- close() releases the lock
- committed property tells readers if weights are valid
class _GMSRPCTransport:
"""Raw GMS Unix socket transport."""
For writers (lock_type=RequestedLockType.RW):
- Use context manager (with statement) for automatic lock release
- Call commit() after weights are written
- Call clear_all() before loading new model
For readers (lock_type=RequestedLockType.RO):
- Check committed property after construction
- Keep connection open during inference lifetime
- Only call close() when shutting down or allowing weight updates
"""
def __init__(
self,
socket_path: str,
lock_type: RequestedLockType = RequestedLockType.RO,
timeout_ms: Optional[int] = None,
):
"""Connect to Allocation Server and acquire lock.
Args:
socket_path: Path to server's Unix domain socket
lock_type: Requested lock type (RW, RO, or RW_OR_RO)
timeout_ms: Timeout in milliseconds for lock acquisition.
None means wait indefinitely.
Raises:
ConnectionError: If connection fails
TimeoutError: If timeout_ms expires waiting for lock
"""
def __init__(self, socket_path: str):
self.socket_path = socket_path
self._requested_lock_type = lock_type
self._socket: Optional[socket.socket] = None
self._recv_buffer = bytearray()
self._committed = False
self._granted_lock_type: Optional[GrantedLockType] = None
# Connect and acquire lock
self._connect(timeout_ms=timeout_ms)
@property
def is_connected(self) -> bool:
return self._socket is not None
def _connect(self, timeout_ms: Optional[int]) -> None:
"""Connect to server and perform handshake (lock acquisition)."""
def connect(self) -> None:
self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
self._socket.connect(self.socket_path)
except FileNotFoundError:
self._socket.close()
self._socket = None
raise ConnectionError(f"Server not running at {self.socket_path}") from None
except Exception as e:
self._socket.close()
self._socket = None
raise ConnectionError(f"Failed to connect: {e}") from e
# Handshake I/O — clean up socket on any failure
try:
request = HandshakeRequest(
lock_type=self._requested_lock_type,
timeout_ms=timeout_ms,
)
send_message_sync(self._socket, request)
# May block waiting for lock
response, _, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
)
except Exception:
self._socket.close()
self._socket = None
raise
if isinstance(response, ErrorResponse):
self._socket.close()
self._socket = None
raise ConnectionError(f"Handshake error: {response.error}")
if not isinstance(response, HandshakeResponse):
self._socket.close()
self._socket = None
raise ConnectionError(f"Unexpected response: {type(response)}")
if not response.success:
raise ConnectionError(
f"GMS server not running at {self.socket_path}"
) from None
except Exception as exc:
self._socket.close()
self._socket = None
raise TimeoutError("Timeout waiting for lock")
self._committed = response.committed
# Store granted lock type (may differ from requested for rw_or_ro mode)
if response.granted_lock_type is not None:
self._granted_lock_type = response.granted_lock_type
elif self._requested_lock_type == RequestedLockType.RW:
self._granted_lock_type = GrantedLockType.RW
else:
self._granted_lock_type = GrantedLockType.RO
logger.info(
f"Connected with {self._requested_lock_type.value} lock (granted={self._granted_lock_type.value}), "
f"committed={self._committed}"
)
raise ConnectionError(f"Failed to connect to GMS: {exc}") from exc
@property
def committed(self) -> bool:
"""Check if weights are committed (valid)."""
return self._committed
@property
def lock_type(self) -> Optional[GrantedLockType]:
"""Get the lock type actually granted by the server.
For rw_or_ro mode, this tells you whether RW or RO was granted.
"""
return self._granted_lock_type
@property
def is_connected(self) -> bool:
"""Check if client is connected."""
return self._socket is not None
def _send_recv(self, request) -> Tuple[object, int]:
"""Send request and receive response. Returns (response, fd)."""
if not self._socket:
raise RuntimeError("Client not connected")
send_message_sync(self._socket, request)
response, fd, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
def handshake(
self,
lock_type: RequestedLockType,
timeout_ms: Optional[int],
) -> HandshakeResponse:
response, _ = self.request_with_fd(
HandshakeRequest(lock_type=lock_type, timeout_ms=timeout_ms),
HandshakeResponse,
error_prefix="GMS handshake",
)
if isinstance(response, ErrorResponse):
raise RuntimeError(f"Server error: {response.error}")
return response, fd
def _call(self, request, response_type: Type[T]) -> T:
"""Send request, validate response type, return typed response."""
if type(request) in RW_REQUIRED and self.lock_type != GrantedLockType.RW:
raise RuntimeError("Operation requires RW connection")
response, _ = self._send_recv(request)
if not isinstance(response, response_type):
raise RuntimeError(f"Unexpected response: {type(response)}")
return response
def get_lock_state(self) -> GetLockStateResponse:
return self._call(GetLockStateRequest(), GetLockStateResponse)
def get_allocation_state(self) -> GetAllocationStateResponse:
return self._call(GetAllocationStateRequest(), GetAllocationStateResponse)
def request(self, request, response_type: Type[T]) -> T:
response, fd = self.request_with_fd(request, response_type)
if fd >= 0:
os.close(fd)
raise RuntimeError(
f"GMS request {type(request).__name__} returned an unexpected FD"
)
return response
def is_ready(self) -> bool:
return self.committed
def request_with_fd(
self,
request,
response_type: Type[T],
*,
error_prefix: Optional[str] = None,
) -> Tuple[T, int]:
response, fd = self._send_recv(request, error_prefix=error_prefix)
if not isinstance(response, response_type):
prefix = error_prefix or f"GMS request {type(request).__name__}"
if fd >= 0:
os.close(fd)
raise RuntimeError(
f"{prefix} returned unexpected response type: {type(response)}"
)
return response, fd
def commit(self) -> bool:
"""Commit weights and release RW lock. Returns True on success."""
if CommitRequest in RW_REQUIRED and self.lock_type != GrantedLockType.RW:
raise RuntimeError("Operation requires RW connection")
def _send_recv(
self, request, *, error_prefix: Optional[str] = None
) -> Tuple[object, int]:
if self._socket is None:
raise RuntimeError("Attempted GMS request on disconnected transport")
prefix = error_prefix or f"GMS request {type(request).__name__}"
try:
response, _ = self._send_recv(CommitRequest())
ok = isinstance(response, CommitResponse) and response.success
except (ConnectionResetError, BrokenPipeError, OSError) as e:
# Server closes RW socket as part of commit
logger.debug(
f"Commit saw socket error ({type(e).__name__}); verifying via RO connect"
send_message_sync(self._socket, request)
response, fd, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
)
self.close()
try:
ro = GMSRPCClient(
self.socket_path, lock_type=RequestedLockType.RO, timeout_ms=1000
)
try:
ok = ro.committed
finally:
ro.close()
except TimeoutError:
ok = False
if ok:
self._committed = True
self.close()
logger.info("Committed weights and released RW connection")
return True
return False
def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]:
"""Returns (allocation_id, aligned_size)."""
r = self._call(AllocateRequest(size=size, tag=tag), AllocateResponse)
return r.allocation_id, r.aligned_size
def export(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD. Caller must close."""
_, fd = self._send_recv(ExportRequest(allocation_id=allocation_id))
if fd < 0:
raise RuntimeError("No FD received from server")
return fd
def get_allocation(self, allocation_id: str) -> GetAllocationResponse:
return self._call(
GetAllocationRequest(allocation_id=allocation_id), GetAllocationResponse
)
def list_allocations(self, tag: Optional[str] = None) -> List[Dict]:
return self._call(
ListAllocationsRequest(tag=tag), ListAllocationsResponse
).allocations
def free(self, allocation_id: str) -> bool:
return self._call(
FreeRequest(allocation_id=allocation_id), FreeResponse
).success
def clear_all(self) -> int:
return self._call(ClearAllRequest(), ClearAllResponse).cleared_count
def metadata_put(
self, key: str, allocation_id: str, offset_bytes: int, value: bytes
) -> bool:
req = MetadataPutRequest(
key=key, allocation_id=allocation_id, offset_bytes=offset_bytes, value=value
)
return self._call(req, MetadataPutResponse).success
def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]:
"""Returns (allocation_id, offset_bytes, value) or None if not found."""
r = self._call(MetadataGetRequest(key=key), MetadataGetResponse)
return (r.allocation_id, r.offset_bytes, r.value) if r.found else None
def metadata_delete(self, key: str) -> bool:
return self._call(
MetadataDeleteRequest(key=key), MetadataDeleteResponse
).deleted
def metadata_list(self, prefix: str = "") -> List[str]:
return self._call(MetadataListRequest(prefix=prefix), MetadataListResponse).keys
def get_memory_layout_hash(self) -> str:
"""Get state hash (hash of allocations + metadata). Empty if not committed."""
return self._call(
GetStateHashRequest(), GetStateHashResponse
).memory_layout_hash
def close(self) -> None:
"""Close connection and release lock."""
if self._socket:
except Exception as exc:
try:
self._socket.close()
except Exception:
pass
self._socket = None
lock_str = self.lock_type.value if self.lock_type else "unknown"
logger.info(f"Closed {lock_str} connection")
raise ConnectionError(f"{prefix} failed: {exc}") from exc
def __enter__(self) -> "GMSRPCClient":
"""Context manager entry."""
if isinstance(response, ErrorResponse):
if fd >= 0:
os.close(fd)
raise RuntimeError(f"{prefix} error: {response.error}")
return response, fd
def close(self) -> None:
if self._socket is None:
return
try:
self._socket.close()
except Exception as exc:
raise ConnectionError(
f"Failed to close GMS transport socket: {exc}"
) from exc
finally:
self._socket = None
def __enter__(self) -> "_GMSRPCTransport":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Context manager exit."""
self.close()
def __del__(self):
"""Destructor: warn if connection not closed."""
if self._socket:
logger.warning("GMSRPCClient not closed properly")
if self._socket is not None:
try:
self._socket.close()
except Exception:
pass
self._socket = None
logger.warning("_GMSRPCTransport not closed properly")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Internal GPU Memory Service client session."""
from __future__ import annotations
import logging
from typing import List, Optional, Tuple
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
CommitRequest,
CommitResponse,
ExportAllocationRequest,
ExportAllocationResponse,
FreeAllocationRequest,
FreeAllocationResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetLockStateRequest,
GetLockStateResponse,
GetStateHashRequest,
GetStateHashResponse,
HandshakeResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__)
class _GMSClientSession:
    """Connected GMS client session with granted lock state.

    The constructor opens a `_GMSRPCTransport`, performs the handshake with
    the requested lock type, and records the lock type the server granted
    together with the ``committed`` flag from the handshake response. The
    remaining methods are thin typed wrappers: each sends one request over
    the transport and unpacks the matching response.
    """

    def __init__(
        self,
        socket_path: str,
        lock_type: RequestedLockType,
        timeout_ms: Optional[int],
    ):
        """Connect to the server and perform the handshake.

        Args:
            socket_path: Path of the server's Unix domain socket.
            lock_type: Lock type to request in the handshake.
            timeout_ms: Handshake timeout forwarded to the server
                (``None`` is forwarded unchanged).

        Raises:
            TimeoutError: If the handshake response reports ``success=False``.
            RuntimeError: If the handshake response has no granted lock type.
            Exception: Whatever the transport raises while connecting or
                handshaking is propagated after the transport is closed.
        """
        self._requested_lock_type = lock_type
        self._transport = _GMSRPCTransport(socket_path)
        self._transport.connect()
        try:
            response = self._transport.handshake(lock_type, timeout_ms)
        except Exception:
            self._close_transport_quietly()
            raise
        self._initialize_from_handshake(response)

    def _close_transport_quietly(self) -> None:
        """Best-effort transport close for failure paths.

        A failure while closing must not replace the error that is already
        being reported to the caller, so it is only logged.
        """
        try:
            self._transport.close()
        except Exception as exc:
            logger.debug("Ignoring transport close failure during cleanup: %s", exc)

    def _initialize_from_handshake(self, response: HandshakeResponse) -> None:
        """Validate the handshake response and store the session state.

        Raises:
            TimeoutError: If ``response.success`` is false.
            RuntimeError: If ``response.granted_lock_type`` is ``None``.
        """
        if not response.success:
            # Close quietly so a close failure cannot mask the TimeoutError.
            self._close_transport_quietly()
            raise TimeoutError("Timeout waiting for lock")

        self._committed = response.committed
        if response.granted_lock_type is None:
            self._close_transport_quietly()
            raise RuntimeError("HandshakeResponse omitted granted_lock_type")
        self._granted_lock_type = response.granted_lock_type

        logger.info(
            "Connected with %s lock (granted=%s), committed=%s",
            self._requested_lock_type.value,
            self._granted_lock_type.value,
            self._committed,
        )

    @property
    def committed(self) -> bool:
        """Committed flag from the handshake, or True after `commit()`."""
        return self._committed

    @property
    def lock_type(self) -> GrantedLockType:
        """Lock type the server granted in the handshake."""
        return self._granted_lock_type

    @property
    def is_connected(self) -> bool:
        """Whether the underlying transport is still connected."""
        return self._transport.is_connected

    def get_lock_state(self) -> GetLockStateResponse:
        """Request the server's lock state."""
        return self._transport.request(GetLockStateRequest(), GetLockStateResponse)

    def get_allocation_state(self) -> GetAllocationStateResponse:
        """Request the server's allocation state."""
        return self._transport.request(
            GetAllocationStateRequest(), GetAllocationStateResponse
        )

    def is_ready(self) -> bool:
        """Return the same value as `committed`."""
        return self.committed

    def commit(self) -> bool:
        """Send a commit request and close the session.

        Returns:
            True on success.

        Raises:
            RuntimeError: If the server's response reports failure.
        """
        response = self._transport.request(CommitRequest(), CommitResponse)
        if not response.success:
            raise RuntimeError("GMS commit returned failure")

        self._committed = True
        try:
            self.close()
        except ConnectionError as exc:
            # The commit itself succeeded; a close failure is not fatal.
            logger.warning("Commit succeeded but closing transport failed: %s", exc)
        logger.info("Committed weights and released RW connection")
        return True

    def allocate_info(self, size: int, tag: str = "default") -> AllocateResponse:
        """Request an allocation and return the full response."""
        return self._transport.request(
            AllocateRequest(size=size, tag=tag), AllocateResponse
        )

    def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]:
        """Request an allocation and return ``(allocation_id, aligned_size)``."""
        response = self.allocate_info(size=size, tag=tag)
        return response.allocation_id, response.aligned_size

    def export(self, allocation_id: str) -> int:
        """Export an allocation and return the file descriptor received.

        Raises:
            RuntimeError: If the response carried no file descriptor.
        """
        _response, fd = self._transport.request_with_fd(
            ExportAllocationRequest(allocation_id=allocation_id),
            ExportAllocationResponse,
        )
        if fd < 0:
            raise RuntimeError(
                f"GMS export returned no FD for allocation_id={allocation_id}"
            )
        return fd

    def get_allocation(self, allocation_id: str) -> GetAllocationResponse:
        """Look up a single allocation by id."""
        return self._transport.request(
            GetAllocationRequest(allocation_id=allocation_id),
            GetAllocationResponse,
        )

    def list_allocations(
        self, tag: Optional[str] = None
    ) -> List[GetAllocationResponse]:
        """Return the ``allocations`` field of a list-allocations response."""
        return self._transport.request(
            ListAllocationsRequest(tag=tag),
            ListAllocationsResponse,
        ).allocations

    def free(self, allocation_id: str) -> bool:
        """Free an allocation; returns the response's ``success`` flag."""
        return self._transport.request(
            FreeAllocationRequest(allocation_id=allocation_id),
            FreeAllocationResponse,
        ).success

    def metadata_put(
        self, key: str, allocation_id: str, offset_bytes: int, value: bytes
    ) -> bool:
        """Store a metadata entry; returns the response's ``success`` flag."""
        return self._transport.request(
            MetadataPutRequest(
                key=key,
                allocation_id=allocation_id,
                offset_bytes=offset_bytes,
                value=value,
            ),
            MetadataPutResponse,
        ).success

    def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]:
        """Return ``(allocation_id, offset_bytes, value)`` or None if not found."""
        response = self._transport.request(
            MetadataGetRequest(key=key), MetadataGetResponse
        )
        if not response.found:
            return None
        return response.allocation_id, response.offset_bytes, response.value

    def metadata_delete(self, key: str) -> bool:
        """Delete a metadata entry; returns the response's ``deleted`` flag."""
        return self._transport.request(
            MetadataDeleteRequest(key=key), MetadataDeleteResponse
        ).deleted

    def metadata_list(self, prefix: str = "") -> List[str]:
        """Return the ``keys`` field of a metadata-list response."""
        return self._transport.request(
            MetadataListRequest(prefix=prefix), MetadataListResponse
        ).keys

    def get_memory_layout_hash(self) -> str:
        """Return the ``memory_layout_hash`` field of a state-hash response."""
        return self._transport.request(
            GetStateHashRequest(), GetStateHashResponse
        ).memory_layout_hash

    def close(self) -> None:
        """Close the transport.

        Safe to call repeatedly. The "Closed" log line is emitted only by
        the call that actually closed a live connection, so `commit()`
        followed by leaving a ``with`` block does not log twice.
        """
        was_connected = self._transport.is_connected
        self._transport.close()
        if was_connected:
            logger.info("Closed %s connection", self._granted_lock_type.value)

    def __enter__(self) -> "_GMSClientSession":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()
......@@ -12,7 +12,9 @@ This module provides PyTorch-specific functionality:
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import (
materialize_module_from_gms,
......@@ -23,6 +25,8 @@ __all__ = [
# GMS client memory manager
"get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
# Tensor operations (public API)
"register_module_tensors",
"materialize_module_from_gms",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service allocator management (singleton).
Manages a single weights memory manager and PyTorch MemPool integration.
Only one GMS scope is needed: weights. KV cache is handled by CuMemAllocator.
"""
"""GPU Memory Service allocator registry for PyTorch integration."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Optional, Tuple
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Optional
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
if TYPE_CHECKING:
import torch
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool
logger = logging.getLogger(__name__)
# Singleton state
_manager: Optional["GMSClientMemoryManager"] = None
_mem_pool: Optional["MemPool"] = None
_tag: str = "weights"
_callbacks_initialized: bool = False
_pluggable_alloc: Optional[Any] = None
@dataclass
class _TagState:
    """Per-tag allocator record held in the module's tag registry."""

    # Client memory manager that owns the mappings created under this tag.
    manager: "GMSClientMemoryManager"
    # Optional torch MemPool associated with the tag.
    mem_pool: "MemPool | None"
    # Socket path recorded for this tag.
    socket_path: str
    # CUDA device index recorded for this tag.
    device: int
_tag_states: dict[str, _TagState] = {}
_active_tag: ContextVar[str | None] = ContextVar(
"gpu_memory_service_active_tag",
default=None,
)
_callbacks_initialized = False
_pluggable_alloc: Any | None = None
def _gms_malloc(size: int, device: int, stream: int) -> int:
"""Route malloc to the singleton weights manager."""
if _manager is None:
raise RuntimeError("No GMS manager initialized")
va = _manager.create_mapping(size=int(size), tag=_tag)
logger.debug("[GMS] malloc: va=0x%x size=%d", va, size)
tag = _active_tag.get()
if tag is None:
raise RuntimeError("No active GMS allocation tag")
state = _tag_states.get(tag)
if state is None:
raise RuntimeError(f"Unknown GMS allocation tag: {tag}")
va = state.manager.create_mapping(size=int(size), tag=tag)
logger.debug("[GMS] malloc(tag=%s): va=0x%x size=%d", tag, va, size)
return va
def _gms_free(ptr: int, size: int, device: int, stream: int) -> None:
"""Route free to the singleton weights manager."""
if _manager is None:
logger.warning("[GMS] free: no manager, ignoring va=0x%x", ptr)
va = int(ptr)
for tag, state in _tag_states.items():
if va not in state.manager.mappings:
continue
logger.debug("[GMS] free(tag=%s): va=0x%x size=%d", tag, va, size)
state.manager.destroy_mapping(va)
return
if int(ptr) in _manager.mappings:
logger.debug("[GMS] free: va=0x%x size=%d", ptr, size)
_manager.destroy_mapping(int(ptr))
else:
logger.warning("[GMS] free: manager does not own va=0x%x, ignoring", ptr)
logger.warning("[GMS] free: no manager owns va=0x%x, ignoring", va)
def _ensure_callbacks_initialized() -> "MemPool":
"""Initialize C-level callbacks exactly once, return a new MemPool."""
def _ensure_callbacks_initialized() -> None:
global _callbacks_initialized, _pluggable_alloc
from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem
from torch.cuda import CUDAPluggableAllocator
from torch.cuda.memory import MemPool
if not _callbacks_initialized:
_pluggable_alloc = CUDAPluggableAllocator(
cumem.__file__, "my_malloc", "my_free"
)
cumem.init_module(_gms_malloc, _gms_free)
_callbacks_initialized = True
if _callbacks_initialized:
return
_pluggable_alloc = CUDAPluggableAllocator(cumem.__file__, "my_malloc", "my_free")
cumem.init_module(_gms_malloc, _gms_free)
_callbacks_initialized = True
def _create_mem_pool() -> "MemPool":
    """Build a new torch MemPool backed by the module's pluggable allocator.

    The module-level ``_pluggable_alloc`` must already be set; this is
    checked with an assertion.
    """
    from torch.cuda.memory import MemPool

    assert _pluggable_alloc is not None
    backing_allocator = _pluggable_alloc.allocator()
    return MemPool(allocator=backing_allocator)
......@@ -74,66 +91,98 @@ def get_or_create_gms_client_memory_manager(
*,
tag: str = "weights",
timeout_ms: Optional[int] = None,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Get existing memory manager, or create a new one.
Args:
socket_path: Unix socket path for the allocation server.
device: CUDA device index.
mode: RW for cold start, RO for import-only, RW_OR_RO for auto.
tag: Allocation tag for RW mode.
timeout_ms: Lock acquisition timeout (None = wait indefinitely).
Returns:
(gms_client_memory_manager, pool) - pool is None for RO mode.
"""
global _manager, _mem_pool, _tag
) -> "GMSClientMemoryManager":
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
if _manager is not None:
return _get_existing(mode)
state = _tag_states.get(tag)
if state is not None:
if state.socket_path != socket_path or state.device != device:
raise RuntimeError(
f"GMS allocator tag={tag} was initialized for "
f"{state.socket_path} on device {state.device}, not {socket_path} "
f"on device {device}"
)
manager = state.manager
if not manager.is_connected:
if manager.mappings or manager.is_unmapped or manager.granted_lock_type:
raise RuntimeError(
f"GMS allocator tag={tag} is disconnected but still owns "
"preserved state; recreate the process instead of reusing it"
)
manager._client = None
manager._granted_lock_type = None
_tag_states.pop(tag, None)
state = None
if state is not None:
current = state.manager.granted_lock_type
if mode == RequestedLockType.RW and current != GrantedLockType.RW:
raise RuntimeError(
f"Cannot get RW allocator for tag {tag}: existing is in {current} mode"
)
if mode == RequestedLockType.RO and current != GrantedLockType.RO:
raise RuntimeError(
f"Cannot get RO allocator for tag {tag}: existing is in {current} mode"
)
return state.manager
manager = GMSClientMemoryManager(socket_path, device=device)
manager.connect(mode, timeout_ms=timeout_ms)
mem_pool = None
if manager.granted_lock_type == GrantedLockType.RW:
pool = _ensure_callbacks_initialized()
# Only set globals after mempool succeeds (avoids partial singleton)
_manager = manager
_tag = tag
_mem_pool = pool
logger.info("[GMS] Created RW allocator (device=%d)", device)
return manager, pool
else:
_manager = manager
_tag = tag
logger.info("[GMS] Created RO allocator (device=%d)", device)
return manager, None
def _get_existing(
mode: RequestedLockType,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Return existing allocator if mode-compatible."""
assert _manager is not None
current = _manager.granted_lock_type
_ensure_callbacks_initialized()
mem_pool = _create_mem_pool()
_tag_states[tag] = _TagState(
manager=manager,
mem_pool=mem_pool,
socket_path=socket_path,
device=device,
)
logger.info(
"[GMS] Created %s allocator for tag=%s (device=%d)",
manager.granted_lock_type.value,
tag,
device,
)
return manager
def get_gms_client_memory_manager(
    tag: str = "weights",
) -> "GMSClientMemoryManager | None":
    """Look up the memory manager registered under ``tag``.

    Args:
        tag: Key into the module-level tag registry.

    Returns:
        The manager stored for ``tag``, or ``None`` when nothing is
        registered under that tag.
    """
    entry = _tag_states.get(tag)
    return entry.manager if entry is not None else None
def get_gms_client_memory_managers() -> tuple["GMSClientMemoryManager", ...]:
    """Return the manager of every registered tag, in registry order."""
    managers = [entry.manager for entry in _tag_states.values()]
    return tuple(managers)
if mode == RequestedLockType.RW:
if current == GrantedLockType.RW:
return _manager, _mem_pool
raise RuntimeError(f"Cannot get RW allocator: existing is in {current} mode")
def evict_gms_client_memory_manager(manager: "GMSClientMemoryManager") -> None:
    """Remove the registry entry whose manager is ``manager``.

    Matching is by identity. Only the first matching tag is dropped, and
    nothing happens when the manager is not registered.
    """
    # Iterate over a snapshot so the registry can be mutated inside the loop.
    for tag, state in tuple(_tag_states.items()):
        if state.manager is not manager:
            continue
        _tag_states.pop(tag, None)
        break
if mode == RequestedLockType.RO:
if current == GrantedLockType.RO:
return _manager, None
raise RuntimeError(f"Cannot get RO allocator: existing is in {current} mode")
# RW_OR_RO: return whatever exists
effective_pool = _mem_pool if current == GrantedLockType.RW else None
return _manager, effective_pool
@contextmanager
def gms_use_mem_pool(tag: str, device: "torch.device | int") -> Iterator[None]:
import torch
state = _tag_states.get(tag)
if state is None:
raise RuntimeError(f"No GMS allocator initialized for tag={tag}")
if state.mem_pool is None:
raise RuntimeError(f"GMS allocator tag={tag} does not have a mempool")
def get_gms_client_memory_manager() -> Optional["GMSClientMemoryManager"]:
"""Get the active GMS client memory manager, or None."""
return _manager
token = _active_tag.set(tag)
try:
with torch.cuda.use_mem_pool(state.mem_pool, device=device):
yield
finally:
_active_tag.reset(token)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA driver helpers shared by the GMS client and server."""
from __future__ import annotations
import atexit
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.utils import fail
_primary_contexts: dict[int, object] = {}
_primary_context_release_registered = False
def cuda_check_result(result: cuda.CUresult, name: str) -> None:
    """Pass a non-success CUDA driver result to ``fail``.

    Args:
        result: Return code from a CUDA driver API call.
        name: Name of the driver call, included in the failure message.
    """
    success = cuda.CUresult.CUDA_SUCCESS
    if result == success:
        return

    # Prefer the driver's own description; fall back to the string form of
    # the result code when the lookup fails or yields nothing.
    lookup_status, description = cuda.cuGetErrorString(result)
    if lookup_status != success or not description:
        detail = str(result)
    elif isinstance(description, bytes):
        detail = description.decode()
    else:
        detail = str(description)
    fail("fatal CUDA VMM error in %s: %s", name, detail)
def cuda_ensure_initialized() -> None:
    """Call ``cuInit(0)`` and route its status through ``cuda_check_result``."""
    (init_status,) = cuda.cuInit(0)
    cuda_check_result(init_status, "cuInit")
def cumem_get_allocation_granularity(device: int) -> int:
    """Query the minimum VMM allocation granularity for ``device``.

    Args:
        device: CUDA device index.

    Returns:
        Granularity in bytes as reported by
        ``cuMemGetAllocationGranularity``.
    """
    fd_handle_type = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )
    minimum_flag = (
        cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
    )

    # Describe a pinned, device-located allocation exportable as a POSIX fd.
    alloc_prop = cuda.CUmemAllocationProp()
    alloc_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    alloc_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    alloc_prop.location.id = device
    alloc_prop.requestedHandleTypes = fd_handle_type

    status, granularity_bytes = cuda.cuMemGetAllocationGranularity(
        alloc_prop, minimum_flag
    )
    cuda_check_result(status, "cuMemGetAllocationGranularity")
    return int(granularity_bytes)
def cumem_create_tolerate_oom(size: int, device: int) -> tuple[bool, int]:
    """Create a VMM allocation, reporting out-of-memory instead of failing.

    Args:
        size: Number of bytes passed to ``cuMemCreate``.
        device: CUDA device index the allocation is located on.

    Returns:
        ``(True, handle)`` on success, or ``(False, 0)`` when the driver
        reports ``CUDA_ERROR_OUT_OF_MEMORY``. Any other error code is handed
        to ``cuda_check_result``.
    """
    # Describe a pinned, device-located allocation exportable as a POSIX fd.
    alloc_prop = cuda.CUmemAllocationProp()
    alloc_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    alloc_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    alloc_prop.location.id = device
    alloc_prop.requestedHandleTypes = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )

    status, raw_handle = cuda.cuMemCreate(size, alloc_prop, 0)
    if status == cuda.CUresult.CUDA_SUCCESS:
        return True, int(raw_handle)
    if status != cuda.CUresult.CUDA_ERROR_OUT_OF_MEMORY:
        # Anything other than OOM is not tolerated.
        cuda_check_result(status, "cuMemCreate")
    return False, 0
def cumem_export_to_shareable_handle(handle: int) -> int:
    """Export ``handle`` as a POSIX file descriptor.

    Args:
        handle: Allocation handle to export.

    Returns:
        The file descriptor returned by ``cuMemExportToShareableHandle``.
    """
    fd_handle_type = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )
    status, shareable_fd = cuda.cuMemExportToShareableHandle(
        handle, fd_handle_type, 0
    )
    cuda_check_result(status, "cuMemExportToShareableHandle")
    return int(shareable_fd)
def align_to_granularity(size: int, granularity: int) -> int:
    """Round ``size`` up to the next multiple of ``granularity``.

    Args:
        size: Size in bytes.
        granularity: Allocation granularity in bytes; must be positive.

    Returns:
        The smallest multiple of ``granularity`` that is >= ``size``.

    Raises:
        ValueError: If ``granularity`` is not positive.
    """
    # A zero granularity would divide by zero and a negative one would
    # silently produce a wrong result, so reject both up front.
    if granularity <= 0:
        raise ValueError(f"granularity must be positive, got {granularity}")
    return ((size + granularity - 1) // granularity) * granularity
def cumem_import_from_shareable_handle_close_fd(fd: int) -> int:
    """Import a VMM allocation from ``fd`` and close ``fd`` afterwards.

    Args:
        fd: POSIX file descriptor for a shareable allocation handle. It is
            closed whether the import returns normally or raises.

    Returns:
        The imported allocation handle as an ``int``.
    """
    try:
        fd_handle_type = (
            cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
        )
        status, raw_handle = cuda.cuMemImportFromShareableHandle(fd, fd_handle_type)
        cuda_check_result(status, "cuMemImportFromShareableHandle")
        imported_handle = int(raw_handle)
    finally:
        os.close(fd)
    return imported_handle
def cumem_address_reserve(size: int, granularity: int) -> int:
    """Reserve a virtual address range and return its base address.

    Args:
        size: Number of bytes to reserve.
        granularity: Passed as the second argument of
            ``cuMemAddressReserve``.

    Returns:
        Base address of the reserved range as an ``int``.
    """
    status, base_va = cuda.cuMemAddressReserve(size, granularity, 0, 0)
    cuda_check_result(status, "cuMemAddressReserve")
    return int(base_va)
def cumem_address_free(va: int, size: int) -> None:
    """Release the ``size``-byte address reservation starting at ``va``."""
    (status,) = cuda.cuMemAddressFree(va, size)
    cuda_check_result(status, "cuMemAddressFree")
def cumem_map(va: int, size: int, handle: int) -> None:
    """Map ``handle`` into ``size`` bytes at ``va`` (offset 0, flags 0)."""
    (status,) = cuda.cuMemMap(va, size, 0, handle, 0)
    cuda_check_result(status, "cuMemMap")
def cumem_set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
    """Set ``device``'s access protection on the range ``[va, va + size)``.

    ``GrantedLockType.RO`` selects read-only protection; every other value
    selects read-write.
    """
    if access == GrantedLockType.RO:
        protection = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
    else:
        protection = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

    descriptor = cuda.CUmemAccessDesc()
    descriptor.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    descriptor.location.id = device
    descriptor.flags = protection

    (status,) = cuda.cuMemSetAccess(va, size, [descriptor], 1)
    cuda_check_result(status, "cuMemSetAccess")
def cumem_unmap(va: int, size: int) -> None:
    """Unmap the ``size``-byte range starting at ``va``."""
    (status,) = cuda.cuMemUnmap(va, size)
    cuda_check_result(status, "cuMemUnmap")
def cumem_release(handle: int) -> None:
    """Release the allocation handle ``handle`` via ``cuMemRelease``."""
    (status,) = cuda.cuMemRelease(handle)
    cuda_check_result(status, "cuMemRelease")
def cuda_validate_pointer(va: int) -> None:
    """Check ``va`` by querying its device-pointer attribute.

    The attribute value itself is discarded; only the status code of
    ``cuPointerGetAttribute`` is checked.
    """
    attribute = cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER
    status, _unused = cuda.cuPointerGetAttribute(attribute, va)
    cuda_check_result(status, "cuPointerGetAttribute")
def cuda_synchronize() -> None:
    """Call ``cuCtxSynchronize`` and check its status."""
    (status,) = cuda.cuCtxSynchronize()
    cuda_check_result(status, "cuCtxSynchronize")
def cuda_set_current_device(device: int) -> None:
    """Set the primary context of ``device`` as the current CUDA context.

    The primary context is retained at most once per device and cached in
    ``_primary_contexts``. The first retain also registers
    ``_release_primary_contexts`` with ``atexit``.

    Args:
        device: CUDA device index.
    """
    global _primary_context_release_registered

    primary_ctx = _primary_contexts.get(device)
    if primary_ctx is None:
        status, primary_ctx = cuda.cuDevicePrimaryCtxRetain(device)
        cuda_check_result(status, "cuDevicePrimaryCtxRetain")
        _primary_contexts[device] = primary_ctx
        # Register the exit hook only once, however many devices are used.
        if not _primary_context_release_registered:
            _primary_context_release_registered = True
            atexit.register(_release_primary_contexts)

    (status,) = cuda.cuCtxSetCurrent(primary_ctx)
    cuda_check_result(status, "cuCtxSetCurrent")
def _release_primary_contexts() -> None:
    """Release every cached primary context, ignoring failures.

    A device is dropped from ``_primary_contexts`` only when its release
    call reports success. Devices whose release raises or returns an error
    stay in the cache.
    """
    # Iterate over a snapshot because entries are removed inside the loop.
    for device in tuple(_primary_contexts):
        try:
            (status,) = cuda.cuDevicePrimaryCtxRelease(device)
        except Exception:
            # Best effort: move on to the next device.
            continue
        if status != cuda.CUresult.CUDA_SUCCESS:
            continue
        _primary_contexts.pop(device, None)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA Virtual Memory Management (VMM) utility functions.
This module provides utility functions for CUDA driver API operations
used by both server (GMSServerMemoryManager) and client (GMSClientMemoryManager).
"""
from cuda.bindings import driver as cuda
def check_cuda_result(result: cuda.CUresult, name: str) -> None:
    """Raise when a CUDA driver API call did not succeed.

    Args:
        result: CUDA driver API return code (``CUresult`` enum).
        name: Operation name, used as the prefix of the error message.

    Raises:
        RuntimeError: If ``result`` is not ``CUDA_SUCCESS``.
    """
    success = cuda.CUresult.CUDA_SUCCESS
    if result == success:
        return

    # Prefer the driver's own description; fall back to the string form of
    # the result code when the lookup fails or yields nothing.
    lookup_status, description = cuda.cuGetErrorString(result)
    if lookup_status != success or not description:
        detail = str(result)
    elif isinstance(description, bytes):
        detail = description.decode()
    else:
        detail = str(description)
    raise RuntimeError(f"{name}: {detail}")
def ensure_cuda_initialized() -> None:
    """Initialize the CUDA driver with ``cuInit(0)``.

    Raises:
        RuntimeError: If ``cuInit`` fails.
    """
    (init_status,) = cuda.cuInit(0)
    check_cuda_result(init_status, "cuInit")
def get_allocation_granularity(device: int) -> int:
    """Query the minimum VMM allocation granularity for ``device``.

    Args:
        device: CUDA device index.

    Returns:
        Granularity in bytes as reported by
        ``cuMemGetAllocationGranularity``.
    """
    fd_handle_type = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )
    minimum_flag = (
        cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
    )

    # Describe a pinned, device-located allocation exportable as a POSIX fd.
    alloc_prop = cuda.CUmemAllocationProp()
    alloc_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    alloc_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    alloc_prop.location.id = device
    alloc_prop.requestedHandleTypes = fd_handle_type

    status, granularity_bytes = cuda.cuMemGetAllocationGranularity(
        alloc_prop, minimum_flag
    )
    check_cuda_result(status, "cuMemGetAllocationGranularity")
    return int(granularity_bytes)
def align_to_granularity(size: int, granularity: int) -> int:
    """Round ``size`` up to the next multiple of ``granularity``.

    Args:
        size: Size in bytes.
        granularity: Allocation granularity in bytes; must be positive.

    Returns:
        The smallest multiple of ``granularity`` that is >= ``size``.

    Raises:
        ValueError: If ``granularity`` is not positive.
    """
    # A zero granularity would divide by zero and a negative one would
    # silently produce a wrong result, so reject both up front.
    if granularity <= 0:
        raise ValueError(f"granularity must be positive, got {granularity}")
    return ((size + granularity - 1) // granularity) * granularity
......@@ -4,7 +4,7 @@
"""Message types for GPU Memory Service RPC protocol."""
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import List, Optional, Union
import msgspec
......@@ -62,7 +62,6 @@ class GetAllocationStateRequest(msgspec.Struct, tag="get_allocation_state_reques
class GetAllocationStateResponse(msgspec.Struct, tag="get_allocation_state_response"):
allocation_count: int
total_bytes: int
class AllocateRequest(msgspec.Struct, tag="allocate_request"):
......@@ -74,12 +73,21 @@ class AllocateResponse(msgspec.Struct, tag="allocate_response"):
allocation_id: str
size: int
aligned_size: int
layout_slot: int
class ExportRequest(msgspec.Struct, tag="export_request"):
class ExportAllocationRequest(msgspec.Struct, tag="export_allocation_request"):
    """Protocol message tagged ``export_allocation_request``.

    NOTE(review): presumably asks the server to export the named allocation;
    confirm against the server-side handler.
    """

    # Identifier of the allocation this request refers to.
    allocation_id: str
class ExportAllocationResponse(msgspec.Struct, tag="export_allocation_response"):
    """Protocol message tagged ``export_allocation_response``.

    NOTE(review): field meanings are inferred from their names; confirm
    against the server code that builds this message.
    """

    allocation_id: str
    size: int  # presumably the requested size in bytes — TODO confirm
    aligned_size: int  # presumably size rounded up to granularity — TODO confirm
    tag: str
    layout_slot: int
class GetAllocationRequest(msgspec.Struct, tag="get_allocation_request"):
allocation_id: str
......@@ -89,6 +97,7 @@ class GetAllocationResponse(msgspec.Struct, tag="get_allocation_response"):
size: int
aligned_size: int
tag: str
layout_slot: int
class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
......@@ -96,25 +105,17 @@ class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
class ListAllocationsResponse(msgspec.Struct, tag="list_allocations_response"):
allocations: List[Dict[str, Any]] = []
allocations: List[GetAllocationResponse] = []
class FreeRequest(msgspec.Struct, tag="free_request"):
class FreeAllocationRequest(msgspec.Struct, tag="free_allocation_request"):
    """Protocol message tagged ``free_allocation_request``."""

    # Identifier of the allocation this request refers to.
    allocation_id: str
class FreeResponse(msgspec.Struct, tag="free_response"):
class FreeAllocationResponse(msgspec.Struct, tag="free_allocation_response"):
    """Protocol message tagged ``free_allocation_response``."""

    # NOTE(review): presumably True when the allocation was freed — confirm.
    success: bool
class ClearAllRequest(msgspec.Struct, tag="clear_all_request"):
pass
class ClearAllResponse(msgspec.Struct, tag="clear_all_response"):
cleared_count: int
class ErrorResponse(msgspec.Struct, tag="error_response"):
error: str
code: int = 0
......@@ -166,6 +167,34 @@ class GetStateHashResponse(msgspec.Struct, tag="get_memory_layout_hash_response"
memory_layout_hash: str # Hash of allocations + metadata, empty if not committed
class GetRuntimeStateRequest(msgspec.Struct, tag="get_runtime_state_request"):
    """Field-less protocol message tagged ``get_runtime_state_request``."""

    pass
class GetRuntimeStateResponse(msgspec.Struct, tag="get_runtime_state_response"):
    """Protocol message tagged ``get_runtime_state_response``.

    The first six fields are required; ``allocation_count`` and
    ``memory_layout_hash`` have defaults.

    NOTE(review): field meanings are inferred from their names; confirm
    against the server code that populates this message.
    """

    state: str  # presumably a server state name — TODO confirm
    has_rw_session: bool
    ro_session_count: int
    waiting_writers: int
    committed: bool
    is_ready: bool
    allocation_count: int = 0
    memory_layout_hash: str = ""
class GMSRuntimeEvent(msgspec.Struct):
    """Untagged struct used as the element type of event-history lists."""

    kind: str  # event category; the set of valid values is not visible here
    allocation_count: int = 0
class GetEventHistoryRequest(msgspec.Struct, tag="get_event_history_request"):
    """Field-less protocol message tagged ``get_event_history_request``."""

    pass
class GetEventHistoryResponse(msgspec.Struct, tag="get_event_history_response"):
    """Protocol message tagged ``get_event_history_response``."""

    # NOTE(review): the empty-list default relies on msgspec giving each
    # instance its own copy of an empty mutable default — confirm.
    events: List[GMSRuntimeEvent] = []
Message = Union[
HandshakeRequest,
HandshakeResponse,
......@@ -177,15 +206,14 @@ Message = Union[
GetAllocationStateResponse,
AllocateRequest,
AllocateResponse,
ExportRequest,
ExportAllocationRequest,
ExportAllocationResponse,
GetAllocationRequest,
GetAllocationResponse,
ListAllocationsRequest,
ListAllocationsResponse,
FreeRequest,
FreeResponse,
ClearAllRequest,
ClearAllResponse,
FreeAllocationRequest,
FreeAllocationResponse,
ErrorResponse,
MetadataPutRequest,
MetadataPutResponse,
......@@ -197,6 +225,10 @@ Message = Union[
MetadataListResponse,
GetStateHashRequest,
GetStateHashResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
GetEventHistoryRequest,
GetEventHistoryResponse,
]
_encoder = msgspec.msgpack.Encoder()
......
......@@ -96,7 +96,11 @@ async def recv_message(
raw_msg, fds, _flags, _addr = await loop.run_in_executor(
None, lambda: socket.recv_fds(raw_sock, 65536, 1)
)
for extra_fd in fds[1:]:
os.close(extra_fd)
if not raw_msg:
if fds:
os.close(fds[0])
raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1
......@@ -107,21 +111,25 @@ async def recv_message(
recv_buffer.extend(chunk)
# Try to extract message, read more if needed
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
if raw_sock is not None:
# Continue reading from raw socket to avoid buffer inconsistency
chunk = await loop.run_in_executor(
None, lambda n=bytes_needed: raw_sock.recv(n)
)
else:
chunk = await reader.read(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
try:
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
if raw_sock is not None:
# Continue reading from raw socket to avoid buffer inconsistency
chunk = await loop.run_in_executor(
None, lambda n=bytes_needed: raw_sock.recv(n)
)
else:
chunk = await reader.read(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
except Exception:
if fd >= 0:
os.close(fd)
raise
# ==================== Sync (for client) ====================
......@@ -153,18 +161,26 @@ def recv_message_sync(
# Receive more data (with potential FD)
raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1)
for extra_fd in fds[1:]:
os.close(extra_fd)
if not raw_msg:
if fds:
os.close(fds[0])
raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1
# Try to extract message, read more if needed
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
chunk = sock.recv(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
try:
msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
while msg is None and bytes_needed > 0:
chunk = sock.recv(bytes_needed)
if not chunk:
raise ConnectionResetError("Connection closed")
remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining
except Exception:
if fd >= 0:
os.close(fd)
raise
......@@ -8,10 +8,9 @@ from enum import Enum, auto
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
ClearAllRequest,
CommitRequest,
ExportRequest,
FreeRequest,
ExportAllocationRequest,
FreeAllocationRequest,
GetAllocationRequest,
GetAllocationStateRequest,
GetLockStateRequest,
......@@ -89,8 +88,7 @@ def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState:
RW_REQUIRED: frozenset[type] = frozenset(
{
AllocateRequest,
FreeRequest,
ClearAllRequest,
FreeAllocationRequest,
MetadataPutRequest,
MetadataDeleteRequest,
CommitRequest,
......@@ -99,7 +97,7 @@ RW_REQUIRED: frozenset[type] = frozenset(
RO_ALLOWED: frozenset[type] = frozenset(
{
ExportRequest,
ExportAllocationRequest,
GetAllocationRequest,
ListAllocationsRequest,
MetadataGetRequest,
......
......@@ -3,36 +3,39 @@
"""Shared utilities for GPU Memory Service."""
import logging
import os
import tempfile
import uuid
from typing import NoReturn
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import (
check_cuda_result,
ensure_cuda_initialized,
)
logger = logging.getLogger(__name__)
def get_socket_path(device: int) -> str:
"""Get GMS socket path for the given CUDA device.
def fail(message: str, *args, exc_info=None) -> NoReturn:
    """Log a critical message, shut logging down, and exit with status 1.

    Args:
        message: %-style format string passed to ``logger.critical``.
        *args: Format arguments for ``message``.
        exc_info: Forwarded unchanged to ``logger.critical``.

    This function never returns.
    """
    logger.critical(message, *args, exc_info=exc_info)
    # Shut logging down explicitly: os._exit() below skips normal interpreter
    # cleanup, so handlers would not be flushed otherwise.
    logging.shutdown()
    os._exit(1)
The socket path is based on GPU UUID resolved by CUDA.
CUDA_VISIBLE_DEVICES remapping is handled by CUDA device enumeration.
def get_socket_path(device: int, tag: str = "weights") -> str:
"""Get GMS socket path for the given CUDA device and tag.
The socket path is based on GPU UUID, making it stable across different
CUDA_VISIBLE_DEVICES configurations.
Args:
device: CUDA device index.
Returns:
Socket path (e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc.sock").
Socket path
(e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock").
"""
ensure_cuda_initialized()
result, cu_device = cuda.cuDeviceGet(device)
check_cuda_result(result, "cuDeviceGet")
result, cu_uuid = cuda.cuDeviceGetUuid(cu_device)
check_cuda_result(result, "cuDeviceGetUuid")
gpu_uuid = f"GPU-{uuid.UUID(bytes=bytes(cu_uuid.bytes))}"
return os.path.join(tempfile.gettempdir(), f"gms_{gpu_uuid}.sock")
import pynvml
pynvml.nvmlInit()
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
uuid = pynvml.nvmlDeviceGetUUID(handle)
finally:
pynvml.nvmlShutdown()
return os.path.join(tempfile.gettempdir(), f"gms_{uuid}_{tag}.sock")
......@@ -8,7 +8,7 @@ from __future__ import annotations
import logging
import torch
from gpu_memory_service import get_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import get_gms_client_memory_managers
logger = logging.getLogger(__name__)
......@@ -32,11 +32,13 @@ def patch_empty_cache() -> None:
_original_empty_cache = torch.cuda.empty_cache
def safe_empty_cache() -> None:
manager = get_gms_client_memory_manager()
if manager is not None and len(manager.mappings) > 0:
mapping_count = sum(
len(manager.mappings) for manager in get_gms_client_memory_managers()
)
if mapping_count > 0:
logger.debug(
"[GMS] Skipping torch.cuda.empty_cache() - %d VMM allocations active",
len(manager.mappings),
mapping_count,
)
return
_original_empty_cache()
......
......@@ -6,6 +6,7 @@
from __future__ import annotations
import logging
from dataclasses import replace
from typing import TYPE_CHECKING
import torch
......@@ -29,6 +30,20 @@ def get_gms_lock_mode(extra_config: dict):
return RequestedLockType.RW_OR_RO
def strip_gms_model_loader_config(load_config, load_format: str):
    """Return a copy of ``load_config`` without ``gms_``-prefixed extra keys.

    Args:
        load_config: Dataclass instance. Its ``model_loader_extra_config``
            mapping is filtered; a missing or empty value is treated as an
            empty mapping.
        load_format: Value stored in the copy's ``load_format`` field.

    Returns:
        A new config built with ``dataclasses.replace``; the input object is
        left unmodified.
    """
    source_extras = getattr(load_config, "model_loader_extra_config", None)
    if not source_extras:
        source_extras = {}

    kept_extras = {}
    for name, setting in source_extras.items():
        if name.startswith("gms_"):
            continue
        kept_extras[name] = setting

    return replace(
        load_config,
        load_format=load_format,
        model_loader_extra_config=kept_extras,
    )
def setup_meta_tensor_workaround() -> None:
"""Enable workaround for meta tensor operations like torch.nonzero()."""
try:
......@@ -42,9 +57,9 @@ def setup_meta_tensor_workaround() -> None:
def finalize_gms_write(
allocator: "GMSClientMemoryManager", model: torch.nn.Module
) -> int:
"""Finalize GMS write mode: register tensors, commit, switch to read.
"""Finalize GMS write mode: register tensors, commit, reconnect in read mode.
Flow: register tensors -> sync -> commit (server-only) -> disconnect -> connect(RO)
Flow: register tensors -> sync -> unmap + commit -> connect(RO) -> remap
Args:
allocator: The GMS client memory manager in write mode.
......@@ -52,9 +67,6 @@ def finalize_gms_write(
Returns:
Total bytes committed.
Raises:
RuntimeError: If commit fails.
"""
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.types import RequestedLockType
......@@ -65,12 +77,10 @@ def finalize_gms_write(
# Synchronize before commit — caller's writes must be visible
torch.cuda.synchronize()
if not allocator.commit():
raise RuntimeError("GMS commit failed")
allocator.commit()
# commit() closed the RW socket; acquire RO for inference
allocator.disconnect() # no-op if commit already cleared _client, but safe
allocator.connect(RequestedLockType.RO)
allocator.remap_all_vas()
logger.info(
"[GMS] Committed %.2f GiB, switched to read mode with %d mappings",
......
......@@ -3,13 +3,10 @@
"""Hybrid torch_memory_saver implementation for GPU Memory Service.
This module provides a hybrid implementation that combines:
1. GPU Memory Service allocator for "weights" tag (VA-stable unmap/remap, shared)
2. Torch mempool mode for other tags like "kv_cache" (CPU backup, per-instance)
The impl uses RW_OR_RO mode to connect to GMS:
- First process gets RW lock and loads weights from disk
- Subsequent processes get RO lock and import weights from metadata
This module uses:
1. GPU Memory Service for "weights" (shared RO/RW publish flow)
2. GPU Memory Service for "kv_cache" (RW-only failover flow)
3. torch_memory_saver for any remaining tags
"""
from __future__ import annotations
......@@ -19,10 +16,13 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
logger = logging.getLogger(__name__)
......@@ -39,56 +39,54 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
class GMSMemorySaverImpl:
"""Hybrid implementation: GMS for weights, torch mempool for KV cache.
Routes operations based on tag:
- "weights" or "model_weights": Handled by GMS allocator (VA-stable)
- Other tags (e.g., "kv_cache"): Delegated to torch mempool mode
"""
"""Hybrid implementation: GMS for weights and KV cache."""
def __init__(
self,
torch_impl: "_TorchMemorySaverImpl",
socket_path: str,
device_index: int,
mode=None,
):
self._torch_impl = torch_impl
self._socket_path = socket_path
self._device_index = device_index
self._requested_mode = mode
self._disabled = False
self._imported_weights_bytes: int = 0
self._allocator: Optional["GMSClientMemoryManager"]
self._mem_pool: Optional["MemPool"]
self._weights_allocator: Optional["GMSClientMemoryManager"]
self._kv_cache_allocator: "GMSClientMemoryManager"
self._mode: str
self._allocator, self._mem_pool, self._mode = self._init_allocator()
(
self._weights_allocator,
self._kv_cache_allocator,
self._mode,
) = self._init_allocators()
logger.info(
"[GMS] Initialized: weights=%s mode (device=%d, socket=%s)",
"[GMS] Initialized weights=%s mode, kv_cache=RW (device=%d)",
self._mode.upper(),
device_index,
socket_path,
)
def _init_allocator(
def _init_allocators(
self,
) -> tuple[Optional["GMSClientMemoryManager"], Optional["MemPool"], str]:
) -> tuple[Optional["GMSClientMemoryManager"], "GMSClientMemoryManager", str,]:
"""Create allocator with mode from config (default: RW_OR_RO)."""
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
mode = self._requested_mode or RequestedLockType.RW_OR_RO
allocator, mem_pool = get_or_create_gms_client_memory_manager(
self._socket_path,
weights_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "weights"),
self._device_index,
mode=mode,
tag="weights",
)
granted_mode = allocator.granted_lock_type
kv_cache_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "kv_cache"),
self._device_index,
mode=RequestedLockType.RW,
tag="kv_cache",
)
granted_mode = weights_allocator.granted_lock_type
if granted_mode == GrantedLockType.RW:
allocator.clear_all_handles()
actual_mode = "write"
else:
actual_mode = "read"
......@@ -97,11 +95,7 @@ class GMSMemorySaverImpl:
actual_mode.upper(),
self._device_index,
)
return (
allocator,
mem_pool if granted_mode == GrantedLockType.RW else None,
actual_mode,
)
return weights_allocator, kv_cache_allocator, actual_mode
def _is_weights_tag(self, tag: Optional[str]) -> bool:
return tag in ("weights", "model_weights")
......@@ -110,25 +104,28 @@ class GMSMemorySaverImpl:
return self._mode
def get_allocator(self) -> Optional["GMSClientMemoryManager"]:
return self._allocator
return self._weights_allocator
@contextmanager
def region(self, tag: str, enable_cpu_backup: bool):
"""Mark allocation region with tag."""
if not self._is_weights_tag(tag):
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
if self._is_weights_tag(tag):
if self._mode == "read":
yield
return
return
if self._mode == "read":
yield
target_device = torch.device("cuda", self._device_index)
with gms_use_mem_pool("weights", target_device):
yield
return
if self._mem_pool is None:
raise RuntimeError("GMS mempool is None in WRITE mode")
if tag == "kv_cache":
target_device = torch.device("cuda", self._device_index)
with gms_use_mem_pool("kv_cache", target_device):
yield
return
target_device = torch.device("cuda", self._device_index)
with torch.cuda.use_mem_pool(self._mem_pool, device=target_device):
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
yield
def pause(self, tag: Optional[str] = None) -> None:
......@@ -136,7 +133,9 @@ class GMSMemorySaverImpl:
return
if tag is None or self._is_weights_tag(tag):
self._pause_weights()
if tag is None or not self._is_weights_tag(tag):
if tag is None or tag == "kv_cache":
self._pause_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.pause(tag=tag)
def resume(self, tag: Optional[str] = None) -> None:
......@@ -144,39 +143,56 @@ class GMSMemorySaverImpl:
return
if tag is None or self._is_weights_tag(tag):
self._resume_weights()
if tag is None or not self._is_weights_tag(tag):
if tag is None or tag == "kv_cache":
self._resume_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.resume(tag=tag)
def _pause_weights(self) -> None:
if self._allocator is None:
if self._weights_allocator is None:
return
if self._allocator.is_unmapped:
if self._weights_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap_all_vas()
self._allocator.disconnect()
self._weights_allocator.unmap_all_vas()
self._weights_allocator.abort()
def _resume_weights(self) -> None:
if self._allocator is None:
if self._weights_allocator is None:
return
if not self._allocator.is_unmapped:
if not self._weights_allocator.is_unmapped:
return
logger.info("[GMS] Remapping weights (VA-stable)")
from gpu_memory_service.common.types import RequestedLockType
self._weights_allocator.connect(RequestedLockType.RO)
self._weights_allocator.remap_all_vas()
self._allocator.connect(RequestedLockType.RO)
self._allocator.remap_all_vas()
def _pause_kv_cache(self) -> None:
if self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping KV cache")
self._kv_cache_allocator.unmap_all_vas()
self._kv_cache_allocator.abort()
def _resume_kv_cache(self) -> None:
if not self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Remapping KV cache")
self._kv_cache_allocator.connect(RequestedLockType.RW)
self._kv_cache_allocator.reallocate_all_handles(tag="kv_cache")
self._kv_cache_allocator.remap_all_vas()
def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read."""
if self._mode != "write":
return
if self._allocator is None:
if self._weights_allocator is None:
raise RuntimeError("Allocator is None in WRITE mode")
from gpu_memory_service.integrations.common.utils import finalize_gms_write
self._imported_weights_bytes = finalize_gms_write(self._allocator, model)
self._imported_weights_bytes = finalize_gms_write(
self._weights_allocator, model
)
self._mode = "read"
def set_imported_weights_bytes(self, bytes_count: int) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment