Unverified commit 91a8c6af, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

fix: simplify GMS layout state and harden GPU-backed flows (#7006)


Signed-off-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
Signed-off-by: Hannah Zhang <hannahz@nvidia.com>
Co-authored-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: hhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent dd7ceb4a
This diff is collapsed.
...@@ -32,7 +32,9 @@ from gpu_memory_service.client.memory_manager import ( ...@@ -32,7 +32,9 @@ from gpu_memory_service.client.memory_manager import (
# PyTorch integration (GMS client memory manager) # PyTorch integration (GMS client memory manager)
from gpu_memory_service.client.torch.allocator import ( from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager, get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager, get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
) )
__all__ = [ __all__ = [
...@@ -42,4 +44,6 @@ __all__ = [ ...@@ -42,4 +44,6 @@ __all__ = [
# GMS client memory manager # GMS client memory manager
"get_or_create_gms_client_memory_manager", "get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager", "get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
] ]
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import argparse import argparse
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.utils import get_socket_path
...@@ -17,7 +18,10 @@ class Config: ...@@ -17,7 +18,10 @@ class Config:
"""Configuration for GPU Memory Service server.""" """Configuration for GPU Memory Service server."""
device: int device: int
tag: str
socket_path: str socket_path: str
alloc_retry_interval: float
alloc_retry_timeout: Optional[float]
verbose: bool verbose: bool
...@@ -33,6 +37,12 @@ def parse_args() -> Config: ...@@ -33,6 +37,12 @@ def parse_args() -> Config:
required=True, required=True,
help="CUDA device ID to manage memory for.", help="CUDA device ID to manage memory for.",
) )
parser.add_argument(
"--tag",
type=str,
default="weights",
help="Logical GMS tag for this server (default: weights).",
)
parser.add_argument( parser.add_argument(
"--socket-path", "--socket-path",
type=str, type=str,
...@@ -45,14 +55,33 @@ def parse_args() -> Config: ...@@ -45,14 +55,33 @@ def parse_args() -> Config:
action="store_true", action="store_true",
help="Enable verbose logging.", help="Enable verbose logging.",
) )
parser.add_argument(
"--alloc-retry-interval",
type=float,
default=0.5,
help="Seconds to sleep between allocation retries on CUDA OOM (default: 0.5).",
)
parser.add_argument(
"--alloc-retry-timeout",
type=float,
default=None,
help="Optional max seconds to wait for allocation retries before failing (default: wait indefinitely).",
)
args = parser.parse_args() args = parser.parse_args()
# Use UUID-based socket path by default (stable across CUDA_VISIBLE_DEVICES) # Use UUID-based socket path by default (stable across CUDA_VISIBLE_DEVICES)
socket_path = args.socket_path or get_socket_path(args.device) socket_path = args.socket_path or get_socket_path(args.device, args.tag)
if args.alloc_retry_interval <= 0:
parser.error("--alloc-retry-interval must be > 0")
if args.alloc_retry_timeout is not None and args.alloc_retry_timeout <= 0:
parser.error("--alloc-retry-timeout must be > 0 when set")
return Config( return Config(
device=args.device, device=args.device,
tag=args.tag,
socket_path=socket_path, socket_path=socket_path,
alloc_retry_interval=args.alloc_retry_interval,
alloc_retry_timeout=args.alloc_retry_timeout,
verbose=args.verbose, verbose=args.verbose,
) )
...@@ -13,7 +13,6 @@ Usage: ...@@ -13,7 +13,6 @@ Usage:
import asyncio import asyncio
import logging import logging
import signal
import uvloop import uvloop
from gpu_memory_service.server import GMSRPCServer from gpu_memory_service.server import GMSRPCServer
...@@ -37,33 +36,28 @@ async def worker() -> None: ...@@ -37,33 +36,28 @@ async def worker() -> None:
logging.getLogger("gpu_memory_service").setLevel(logging.DEBUG) logging.getLogger("gpu_memory_service").setLevel(logging.DEBUG)
logger.info(f"Starting GPU Memory Service Server for device {config.device}") logger.info(f"Starting GPU Memory Service Server for device {config.device}")
logger.info("GMS tag: %s", config.tag)
logger.info(f"Socket path: {config.socket_path}") logger.info(f"Socket path: {config.socket_path}")
logger.info(
server = GMSRPCServer(config.socket_path, device=config.device) "Allocation retry config: interval=%ss timeout=%s",
config.alloc_retry_interval,
# Set up shutdown handling (
shutdown_event = asyncio.Event() f"{config.alloc_retry_timeout}s"
if config.alloc_retry_timeout is not None
def signal_handler(): else "none"
logger.info("Received shutdown signal") ),
shutdown_event.set() )
loop = asyncio.get_running_loop() server = GMSRPCServer(
for sig in (signal.SIGTERM, signal.SIGINT): config.socket_path,
loop.add_signal_handler(sig, signal_handler) device=config.device,
allocation_retry_interval=config.alloc_retry_interval,
await server.start() allocation_retry_timeout=config.alloc_retry_timeout,
)
logger.info("GPU Memory Service Server ready, waiting for connections...") logger.info("GPU Memory Service Server ready, waiting for connections...")
logger.info(f"Clients can connect via socket: {config.socket_path}") logger.info(f"Clients can connect via socket: {config.socket_path}")
await server.serve()
# Wait for shutdown signal
try:
await shutdown_event.wait()
finally:
logger.info("Shutting down GPU Memory Service Server...")
await server.stop()
logger.info("GPU Memory Service Server shutdown complete")
def main() -> None: def main() -> None:
......
...@@ -7,7 +7,6 @@ This module provides the client-side components for interacting with the ...@@ -7,7 +7,6 @@ This module provides the client-side components for interacting with the
GPU Memory Service: GPU Memory Service:
- GMSClientMemoryManager: Manages local VA mappings of remote GPU memory - GMSClientMemoryManager: Manages local VA mappings of remote GPU memory
- GMSRPCClient: Low-level RPC client (pure Python, no PyTorch dependency)
For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch. For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch.
""" """
...@@ -16,10 +15,8 @@ from gpu_memory_service.client.memory_manager import ( ...@@ -16,10 +15,8 @@ from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager, GMSClientMemoryManager,
StaleMemoryLayoutError, StaleMemoryLayoutError,
) )
from gpu_memory_service.client.rpc import GMSRPCClient
__all__ = [ __all__ = [
"GMSClientMemoryManager", "GMSClientMemoryManager",
"StaleMemoryLayoutError", "StaleMemoryLayoutError",
"GMSRPCClient",
] ]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Client-side CUDA VMM utilities.
These functions wrap CUDA driver API calls used by the client memory manager
for importing, mapping, and unmapping GPU memory.
"""
from __future__ import annotations
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import check_cuda_result
from gpu_memory_service.common.types import GrantedLockType
def import_handle_from_fd(fd: int) -> int:
    """Turn a received POSIX file descriptor into a CUDA memory handle.

    The descriptor is always closed before returning, whether or not the
    import succeeded: the imported handle keeps its own reference to the
    physical allocation, and a lingering FD would hold a DMA-buf reference
    that stops cuMemRelease from actually freeing the GPU memory.

    Args:
        fd: POSIX file descriptor received via SCM_RIGHTS.

    Returns:
        The imported CUDA memory handle as a plain int.
    """
    try:
        fd_handle_type = (
            cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
        )
        status, imported = cuda.cuMemImportFromShareableHandle(fd, fd_handle_type)
        check_cuda_result(status, "cuMemImportFromShareableHandle")
        return int(imported)
    finally:
        # Runs on success and on failure alike; see the docstring for why.
        os.close(fd)
def reserve_va(size: int, granularity: int) -> int:
    """Reserve a range of virtual address space via cuMemAddressReserve.

    Args:
        size: Number of bytes to reserve (should be aligned to granularity).
        granularity: VMM allocation granularity, passed as the second
            argument of cuMemAddressReserve.

    Returns:
        The reserved virtual address as a plain int.
    """
    # The trailing two zeros are passed through unchanged from the original
    # call; presumably "no fixed address hint" and "no flags" - confirm
    # against the CUDA driver API reference.
    status, base_va = cuda.cuMemAddressReserve(size, granularity, 0, 0)
    check_cuda_result(status, "cuMemAddressReserve")
    return int(base_va)
def free_va(va: int, size: int) -> None:
    """Give back a virtual address reservation via cuMemAddressFree.

    Args:
        va: Virtual address of the reservation to free.
        size: Size of the reservation in bytes.
    """
    # The binding returns a 1-tuple; unpacking enforces that shape.
    [status] = cuda.cuMemAddressFree(va, size)
    check_cuda_result(status, "cuMemAddressFree")
def map_to_va(va: int, size: int, handle: int) -> None:
    """Back a reserved virtual address range with a CUDA memory handle.

    Args:
        va: Virtual address to map at (must already be reserved).
        size: Size of the mapping in bytes.
        handle: CUDA memory handle providing the physical memory.
    """
    # Third and fifth arguments are literal zeros, as in the original call;
    # presumably offset-into-handle and flags - confirm against the CUDA
    # driver API reference.
    [status] = cuda.cuMemMap(va, size, 0, handle, 0)
    check_cuda_result(status, "cuMemMap")
def set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
    """Apply read-only or read-write protection to a mapped region.

    Args:
        va: Virtual address of the mapped region.
        size: Size of the region in bytes.
        device: CUDA device index the access descriptor is issued for.
        access: GrantedLockType.RO selects read-only protection; any other
            value selects read-write.
    """
    descriptor = cuda.CUmemAccessDesc()
    descriptor.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    descriptor.location.id = device
    if access == GrantedLockType.RO:
        descriptor.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
    else:
        descriptor.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

    # One descriptor in the list, count of 1.
    [status] = cuda.cuMemSetAccess(va, size, [descriptor], 1)
    check_cuda_result(status, "cuMemSetAccess")
def unmap(va: int, size: int) -> None:
    """Remove the mapping from a virtual address region via cuMemUnmap.

    Args:
        va: Virtual address of the region to unmap.
        size: Size of the mapping in bytes.
    """
    [status] = cuda.cuMemUnmap(va, size)
    check_cuda_result(status, "cuMemUnmap")
def release_handle(handle: int) -> None:
    """Drop a CUDA memory handle via cuMemRelease.

    Args:
        handle: CUDA memory handle to release.
    """
    [status] = cuda.cuMemRelease(handle)
    check_cuda_result(status, "cuMemRelease")
def validate_pointer(va: int) -> bool:
    """Check that a mapped virtual address is accessible.

    Queries the device-pointer attribute for ``va``. A failed query is not
    raised; it is logged as a warning together with the CUDA error string
    (when one can be looked up) and reported through the return value.

    Args:
        va: Virtual address to probe.

    Returns:
        True if the attribute query succeeded, False otherwise.
    """
    status, _unused_ptr = cuda.cuPointerGetAttribute(
        cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va
    )
    query_failed = status != cuda.CUresult.CUDA_SUCCESS
    if not query_failed:
        return True

    # Best-effort translation of the error code into readable text; the
    # detail stays empty if the lookup itself fails or yields nothing.
    lookup_status, raw_text = cuda.cuGetErrorString(status)
    if lookup_status == cuda.CUresult.CUDA_SUCCESS and raw_text:
        detail = raw_text.decode() if isinstance(raw_text, bytes) else str(raw_text)
    else:
        detail = ""

    import logging

    logging.getLogger(__name__).warning(
        "cuPointerGetAttribute failed for VA 0x%x: %s (%s)",
        va,
        status,
        detail,
    )
    return False
def synchronize() -> None:
    """Wait for the current CUDA context to finish its outstanding work.

    Calls cuCtxSynchronize and raises through check_cuda_result on failure.
    """
    [status] = cuda.cuCtxSynchronize()
    check_cuda_result(status, "cuCtxSynchronize")
def set_current_device(device: int) -> None:
    """Make a device current by activating its primary context.

    Retains the primary context of ``device`` and then sets it as the
    current context.

    Args:
        device: CUDA device index.
    """
    retain_status, primary_ctx = cuda.cuDevicePrimaryCtxRetain(device)
    check_cuda_result(retain_status, "cuDevicePrimaryCtxRetain")

    [set_status] = cuda.cuCtxSetCurrent(primary_ctx)
    check_cuda_result(set_status, "cuCtxSetCurrent")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service RPC Client. """Internal GPU Memory Service transport.
Low-level RPC client stub. The client provides a simple interface for acquiring This module only owns Unix socket transport and typed request/response exchange.
locks and performing allocation operations. The socket connection IS the lock. Session semantics live in `gpu_memory_service.client.session`.
This module has NO PyTorch dependency.
Usage:
# Writer (acquires RW lock in constructor)
with GMSRPCClient(socket_path, lock_type=RequestedLockType.RW) as client:
alloc_id, aligned_size = client.allocate(size=1024*1024)
fd = client.export(alloc_id)
# ... write weights using fd ...
client.commit()
# Lock released on exit
# Reader (acquires RO lock in constructor)
client = GMSRPCClient(socket_path, lock_type=RequestedLockType.RO)
if client.committed: # Check if weights are valid
allocations = client.list_allocations()
for alloc in allocations:
fd = client.export(alloc["allocation_id"])
# ... import and map fd ...
# Keep connection open during inference!
# client.close() only when done with inference
""" """
from __future__ import annotations
import logging import logging
import os
import socket import socket
from typing import Dict, List, Optional, Tuple, Type, TypeVar from typing import Optional, Tuple, Type, TypeVar
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
ClearAllRequest,
ClearAllResponse,
CommitRequest,
CommitResponse,
ErrorResponse, ErrorResponse,
ExportRequest,
FreeRequest,
FreeResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetLockStateRequest,
GetLockStateResponse,
GetStateHashRequest,
GetStateHashResponse,
HandshakeRequest, HandshakeRequest,
HandshakeResponse, HandshakeResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
) )
from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync
from gpu_memory_service.common.types import ( from gpu_memory_service.common.types import RequestedLockType
RW_REQUIRED,
GrantedLockType,
RequestedLockType,
)
T = TypeVar("T") T = TypeVar("T")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GMSRPCClient: class _GMSRPCTransport:
"""GPU Memory Service RPC Client. """Raw GMS Unix socket transport."""
CRITICAL: Socket connection IS the lock.
- Constructor blocks until lock is acquired
- close() releases the lock
- committed property tells readers if weights are valid
For writers (lock_type=RequestedLockType.RW): def __init__(self, socket_path: str):
- Use context manager (with statement) for automatic lock release
- Call commit() after weights are written
- Call clear_all() before loading new model
For readers (lock_type=RequestedLockType.RO):
- Check committed property after construction
- Keep connection open during inference lifetime
- Only call close() when shutting down or allowing weight updates
"""
def __init__(
self,
socket_path: str,
lock_type: RequestedLockType = RequestedLockType.RO,
timeout_ms: Optional[int] = None,
):
"""Connect to Allocation Server and acquire lock.
Args:
socket_path: Path to server's Unix domain socket
lock_type: Requested lock type (RW, RO, or RW_OR_RO)
timeout_ms: Timeout in milliseconds for lock acquisition.
None means wait indefinitely.
Raises:
ConnectionError: If connection fails
TimeoutError: If timeout_ms expires waiting for lock
"""
self.socket_path = socket_path self.socket_path = socket_path
self._requested_lock_type = lock_type
self._socket: Optional[socket.socket] = None self._socket: Optional[socket.socket] = None
self._recv_buffer = bytearray() self._recv_buffer = bytearray()
self._committed = False
self._granted_lock_type: Optional[GrantedLockType] = None
# Connect and acquire lock @property
self._connect(timeout_ms=timeout_ms) def is_connected(self) -> bool:
return self._socket is not None
def _connect(self, timeout_ms: Optional[int]) -> None: def connect(self) -> None:
"""Connect to server and perform handshake (lock acquisition)."""
self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try: try:
self._socket.connect(self.socket_path) self._socket.connect(self.socket_path)
except FileNotFoundError: except FileNotFoundError:
self._socket.close() self._socket.close()
self._socket = None self._socket = None
raise ConnectionError(f"Server not running at {self.socket_path}") from None raise ConnectionError(
except Exception as e: f"GMS server not running at {self.socket_path}"
self._socket.close() ) from None
self._socket = None except Exception as exc:
raise ConnectionError(f"Failed to connect: {e}") from e
# Handshake I/O — clean up socket on any failure
try:
request = HandshakeRequest(
lock_type=self._requested_lock_type,
timeout_ms=timeout_ms,
)
send_message_sync(self._socket, request)
# May block waiting for lock
response, _, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
)
except Exception:
self._socket.close()
self._socket = None
raise
if isinstance(response, ErrorResponse):
self._socket.close()
self._socket = None
raise ConnectionError(f"Handshake error: {response.error}")
if not isinstance(response, HandshakeResponse):
self._socket.close()
self._socket = None
raise ConnectionError(f"Unexpected response: {type(response)}")
if not response.success:
self._socket.close() self._socket.close()
self._socket = None self._socket = None
raise TimeoutError("Timeout waiting for lock") raise ConnectionError(f"Failed to connect to GMS: {exc}") from exc
self._committed = response.committed
# Store granted lock type (may differ from requested for rw_or_ro mode)
if response.granted_lock_type is not None:
self._granted_lock_type = response.granted_lock_type
elif self._requested_lock_type == RequestedLockType.RW:
self._granted_lock_type = GrantedLockType.RW
else:
self._granted_lock_type = GrantedLockType.RO
logger.info(
f"Connected with {self._requested_lock_type.value} lock (granted={self._granted_lock_type.value}), "
f"committed={self._committed}"
)
@property def handshake(
def committed(self) -> bool: self,
"""Check if weights are committed (valid).""" lock_type: RequestedLockType,
return self._committed timeout_ms: Optional[int],
) -> HandshakeResponse:
@property response, _ = self.request_with_fd(
def lock_type(self) -> Optional[GrantedLockType]: HandshakeRequest(lock_type=lock_type, timeout_ms=timeout_ms),
"""Get the lock type actually granted by the server. HandshakeResponse,
error_prefix="GMS handshake",
For rw_or_ro mode, this tells you whether RW or RO was granted.
"""
return self._granted_lock_type
@property
def is_connected(self) -> bool:
"""Check if client is connected."""
return self._socket is not None
def _send_recv(self, request) -> Tuple[object, int]:
"""Send request and receive response. Returns (response, fd)."""
if not self._socket:
raise RuntimeError("Client not connected")
send_message_sync(self._socket, request)
response, fd, self._recv_buffer = recv_message_sync(
self._socket, self._recv_buffer
) )
if isinstance(response, ErrorResponse):
raise RuntimeError(f"Server error: {response.error}")
return response, fd
def _call(self, request, response_type: Type[T]) -> T:
"""Send request, validate response type, return typed response."""
if type(request) in RW_REQUIRED and self.lock_type != GrantedLockType.RW:
raise RuntimeError("Operation requires RW connection")
response, _ = self._send_recv(request)
if not isinstance(response, response_type):
raise RuntimeError(f"Unexpected response: {type(response)}")
return response return response
def get_lock_state(self) -> GetLockStateResponse: def request(self, request, response_type: Type[T]) -> T:
return self._call(GetLockStateRequest(), GetLockStateResponse) response, fd = self.request_with_fd(request, response_type)
if fd >= 0:
def get_allocation_state(self) -> GetAllocationStateResponse: os.close(fd)
return self._call(GetAllocationStateRequest(), GetAllocationStateResponse) raise RuntimeError(
f"GMS request {type(request).__name__} returned an unexpected FD"
)
return response
def is_ready(self) -> bool: def request_with_fd(
return self.committed self,
request,
response_type: Type[T],
*,
error_prefix: Optional[str] = None,
) -> Tuple[T, int]:
response, fd = self._send_recv(request, error_prefix=error_prefix)
if not isinstance(response, response_type):
prefix = error_prefix or f"GMS request {type(request).__name__}"
if fd >= 0:
os.close(fd)
raise RuntimeError(
f"{prefix} returned unexpected response type: {type(response)}"
)
return response, fd
def commit(self) -> bool: def _send_recv(
"""Commit weights and release RW lock. Returns True on success.""" self, request, *, error_prefix: Optional[str] = None
if CommitRequest in RW_REQUIRED and self.lock_type != GrantedLockType.RW: ) -> Tuple[object, int]:
raise RuntimeError("Operation requires RW connection") if self._socket is None:
raise RuntimeError("Attempted GMS request on disconnected transport")
prefix = error_prefix or f"GMS request {type(request).__name__}"
try: try:
response, _ = self._send_recv(CommitRequest()) send_message_sync(self._socket, request)
ok = isinstance(response, CommitResponse) and response.success response, fd, self._recv_buffer = recv_message_sync(
except (ConnectionResetError, BrokenPipeError, OSError) as e: self._socket, self._recv_buffer
# Server closes RW socket as part of commit
logger.debug(
f"Commit saw socket error ({type(e).__name__}); verifying via RO connect"
) )
self.close() except Exception as exc:
try:
ro = GMSRPCClient(
self.socket_path, lock_type=RequestedLockType.RO, timeout_ms=1000
)
try:
ok = ro.committed
finally:
ro.close()
except TimeoutError:
ok = False
if ok:
self._committed = True
self.close()
logger.info("Committed weights and released RW connection")
return True
return False
def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]:
"""Returns (allocation_id, aligned_size)."""
r = self._call(AllocateRequest(size=size, tag=tag), AllocateResponse)
return r.allocation_id, r.aligned_size
def export(self, allocation_id: str) -> int:
"""Export allocation as POSIX FD. Caller must close."""
_, fd = self._send_recv(ExportRequest(allocation_id=allocation_id))
if fd < 0:
raise RuntimeError("No FD received from server")
return fd
def get_allocation(self, allocation_id: str) -> GetAllocationResponse:
return self._call(
GetAllocationRequest(allocation_id=allocation_id), GetAllocationResponse
)
def list_allocations(self, tag: Optional[str] = None) -> List[Dict]:
return self._call(
ListAllocationsRequest(tag=tag), ListAllocationsResponse
).allocations
def free(self, allocation_id: str) -> bool:
return self._call(
FreeRequest(allocation_id=allocation_id), FreeResponse
).success
def clear_all(self) -> int:
return self._call(ClearAllRequest(), ClearAllResponse).cleared_count
def metadata_put(
self, key: str, allocation_id: str, offset_bytes: int, value: bytes
) -> bool:
req = MetadataPutRequest(
key=key, allocation_id=allocation_id, offset_bytes=offset_bytes, value=value
)
return self._call(req, MetadataPutResponse).success
def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]:
"""Returns (allocation_id, offset_bytes, value) or None if not found."""
r = self._call(MetadataGetRequest(key=key), MetadataGetResponse)
return (r.allocation_id, r.offset_bytes, r.value) if r.found else None
def metadata_delete(self, key: str) -> bool:
return self._call(
MetadataDeleteRequest(key=key), MetadataDeleteResponse
).deleted
def metadata_list(self, prefix: str = "") -> List[str]:
return self._call(MetadataListRequest(prefix=prefix), MetadataListResponse).keys
def get_memory_layout_hash(self) -> str:
"""Get state hash (hash of allocations + metadata). Empty if not committed."""
return self._call(
GetStateHashRequest(), GetStateHashResponse
).memory_layout_hash
def close(self) -> None:
"""Close connection and release lock."""
if self._socket:
try: try:
self._socket.close() self._socket.close()
except Exception: except Exception:
pass pass
self._socket = None self._socket = None
lock_str = self.lock_type.value if self.lock_type else "unknown" raise ConnectionError(f"{prefix} failed: {exc}") from exc
logger.info(f"Closed {lock_str} connection")
def __enter__(self) -> "GMSRPCClient": if isinstance(response, ErrorResponse):
"""Context manager entry.""" if fd >= 0:
os.close(fd)
raise RuntimeError(f"{prefix} error: {response.error}")
return response, fd
def close(self) -> None:
if self._socket is None:
return
try:
self._socket.close()
except Exception as exc:
raise ConnectionError(
f"Failed to close GMS transport socket: {exc}"
) from exc
finally:
self._socket = None
def __enter__(self) -> "_GMSRPCTransport":
return self return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None: def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Context manager exit."""
self.close() self.close()
def __del__(self): def __del__(self):
"""Destructor: warn if connection not closed.""" if self._socket is not None:
if self._socket: try:
logger.warning("GMSRPCClient not closed properly") self._socket.close()
except Exception:
pass
self._socket = None
logger.warning("_GMSRPCTransport not closed properly")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Internal GPU Memory Service client session."""
from __future__ import annotations
import logging
from typing import List, Optional, Tuple
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
CommitRequest,
CommitResponse,
ExportAllocationRequest,
ExportAllocationResponse,
FreeAllocationRequest,
FreeAllocationResponse,
GetAllocationRequest,
GetAllocationResponse,
GetAllocationStateRequest,
GetAllocationStateResponse,
GetLockStateRequest,
GetLockStateResponse,
GetStateHashRequest,
GetStateHashResponse,
HandshakeResponse,
ListAllocationsRequest,
ListAllocationsResponse,
MetadataDeleteRequest,
MetadataDeleteResponse,
MetadataGetRequest,
MetadataGetResponse,
MetadataListRequest,
MetadataListResponse,
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__)
class _GMSClientSession:
    """Connected GMS client session with granted lock state.

    Construction connects a `_GMSRPCTransport` to ``socket_path`` and performs
    the handshake; the committed flag and the lock type granted by the server
    are then cached on the session. Every other method is a thin typed
    request/response exchange over that transport.
    """

    def __init__(
        self,
        socket_path: str,
        lock_type: RequestedLockType,
        timeout_ms: Optional[int],
    ):
        """Connect to the server and perform the handshake.

        Args:
            socket_path: Path of the server's Unix domain socket.
            lock_type: Lock type to request in the handshake.
            timeout_ms: Timeout forwarded in the handshake request; None is
                passed through unchanged.

        Raises:
            TimeoutError: The handshake response reported ``success`` false.
            RuntimeError: The handshake response carried no granted lock type.
            Exception: Whatever the transport raises from connect/handshake.
        """
        self._requested_lock_type = lock_type
        self._transport = _GMSRPCTransport(socket_path)
        self._transport.connect()
        try:
            response = self._transport.handshake(lock_type, timeout_ms)
        except Exception:
            self._close_transport_quietly()
            raise
        self._initialize_from_handshake(response)

    def _close_transport_quietly(self) -> None:
        """Close the transport on a failure path without masking the real error."""
        try:
            self._transport.close()
        except Exception as exc:
            # Cleanup is best-effort: the caller is about to raise the error
            # that actually matters, so a close failure is only recorded.
            logger.debug("Ignoring error while closing GMS transport: %s", exc)

    def _initialize_from_handshake(self, response: HandshakeResponse) -> None:
        """Validate the handshake response and cache its lock state.

        On either failure below the transport is closed best-effort first, so
        the documented exception is what the caller sees.

        Raises:
            TimeoutError: ``response.success`` is false.
            RuntimeError: ``response.granted_lock_type`` is None.
        """
        if not response.success:
            self._close_transport_quietly()
            raise TimeoutError("Timeout waiting for lock")

        self._committed = response.committed
        if response.granted_lock_type is None:
            self._close_transport_quietly()
            raise RuntimeError("HandshakeResponse omitted granted_lock_type")
        self._granted_lock_type = response.granted_lock_type

        logger.info(
            "Connected with %s lock (granted=%s), committed=%s",
            self._requested_lock_type.value,
            self._granted_lock_type.value,
            self._committed,
        )

    @property
    def committed(self) -> bool:
        """Committed flag as last known to this session."""
        return self._committed

    @property
    def lock_type(self) -> GrantedLockType:
        """Lock type granted by the server in the handshake."""
        return self._granted_lock_type

    @property
    def is_connected(self) -> bool:
        """Whether the underlying transport still reports a connection."""
        return self._transport.is_connected

    def get_lock_state(self) -> GetLockStateResponse:
        """Request the server's lock state."""
        return self._transport.request(GetLockStateRequest(), GetLockStateResponse)

    def get_allocation_state(self) -> GetAllocationStateResponse:
        """Request the server's allocation state."""
        return self._transport.request(
            GetAllocationStateRequest(), GetAllocationStateResponse
        )

    def is_ready(self) -> bool:
        """Alias for the `committed` property."""
        return self.committed

    def commit(self) -> bool:
        """Send a commit request, then close this session.

        Returns:
            True once the server has acknowledged the commit.

        Raises:
            RuntimeError: The commit response reported ``success`` false.
        """
        response = self._transport.request(CommitRequest(), CommitResponse)
        if not response.success:
            raise RuntimeError("GMS commit returned failure")

        self._committed = True
        try:
            self.close()
        except ConnectionError as exc:
            # The commit itself was acknowledged; a close failure afterwards
            # must not turn a successful commit into an error.
            logger.warning("Commit succeeded but closing transport failed: %s", exc)
        logger.info("Committed weights and released RW connection")
        return True

    def allocate_info(self, size: int, tag: str = "default") -> AllocateResponse:
        """Request an allocation and return the full response."""
        return self._transport.request(
            AllocateRequest(size=size, tag=tag), AllocateResponse
        )

    def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]:
        """Request an allocation and return ``(allocation_id, aligned_size)``."""
        response = self.allocate_info(size=size, tag=tag)
        return response.allocation_id, response.aligned_size

    def export(self, allocation_id: str) -> int:
        """Export an allocation and return the file descriptor received.

        This method does not close the descriptor it returns.

        Raises:
            RuntimeError: The response arrived without a descriptor.
        """
        _, fd = self._transport.request_with_fd(
            ExportAllocationRequest(allocation_id=allocation_id),
            ExportAllocationResponse,
        )
        if fd < 0:
            raise RuntimeError(
                f"GMS export returned no FD for allocation_id={allocation_id}"
            )
        return fd

    def get_allocation(self, allocation_id: str) -> GetAllocationResponse:
        """Request the server's record for one allocation."""
        return self._transport.request(
            GetAllocationRequest(allocation_id=allocation_id),
            GetAllocationResponse,
        )

    def list_allocations(
        self, tag: Optional[str] = None
    ) -> List[GetAllocationResponse]:
        """Return the ``allocations`` field of a list-allocations response."""
        return self._transport.request(
            ListAllocationsRequest(tag=tag),
            ListAllocationsResponse,
        ).allocations

    def free(self, allocation_id: str) -> bool:
        """Request that an allocation be freed; returns the response's success flag."""
        return self._transport.request(
            FreeAllocationRequest(allocation_id=allocation_id),
            FreeAllocationResponse,
        ).success

    def metadata_put(
        self, key: str, allocation_id: str, offset_bytes: int, value: bytes
    ) -> bool:
        """Store a metadata entry; returns the response's success flag."""
        return self._transport.request(
            MetadataPutRequest(
                key=key,
                allocation_id=allocation_id,
                offset_bytes=offset_bytes,
                value=value,
            ),
            MetadataPutResponse,
        ).success

    def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]:
        """Return ``(allocation_id, offset_bytes, value)`` for a key, or None if not found."""
        response = self._transport.request(
            MetadataGetRequest(key=key), MetadataGetResponse
        )
        if not response.found:
            return None
        return response.allocation_id, response.offset_bytes, response.value

    def metadata_delete(self, key: str) -> bool:
        """Delete a metadata entry; returns the response's deleted flag."""
        return self._transport.request(
            MetadataDeleteRequest(key=key), MetadataDeleteResponse
        ).deleted

    def metadata_list(self, prefix: str = "") -> List[str]:
        """Return the keys reported by the server for ``prefix``."""
        return self._transport.request(
            MetadataListRequest(prefix=prefix), MetadataListResponse
        ).keys

    def get_memory_layout_hash(self) -> str:
        """Return the ``memory_layout_hash`` field of a state-hash response."""
        return self._transport.request(
            GetStateHashRequest(), GetStateHashResponse
        ).memory_layout_hash

    def close(self) -> None:
        """Close the transport; safe to call more than once.

        The close is only logged when this call found the transport still
        connected, so a commit followed by a context-manager exit does not
        report the same close twice.
        """
        was_connected = self._transport.is_connected
        self._transport.close()
        if was_connected:
            logger.info("Closed %s connection", self._granted_lock_type.value)

    def __enter__(self) -> "_GMSClientSession":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()
...@@ -12,7 +12,9 @@ This module provides PyTorch-specific functionality: ...@@ -12,7 +12,9 @@ This module provides PyTorch-specific functionality:
from gpu_memory_service.client.torch.allocator import ( from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager, get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager, get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
) )
from gpu_memory_service.client.torch.module import ( from gpu_memory_service.client.torch.module import (
materialize_module_from_gms, materialize_module_from_gms,
...@@ -23,6 +25,8 @@ __all__ = [ ...@@ -23,6 +25,8 @@ __all__ = [
# GMS client memory manager # GMS client memory manager
"get_or_create_gms_client_memory_manager", "get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager", "get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
# Tensor operations (public API) # Tensor operations (public API)
"register_module_tensors", "register_module_tensors",
"materialize_module_from_gms", "materialize_module_from_gms",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service allocator management (singleton). """GPU Memory Service allocator registry for PyTorch integration."""
Manages a single weights memory manager and PyTorch MemPool integration.
Only one GMS scope is needed: weights. KV cache is handled by CuMemAllocator.
"""
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any, Optional, Tuple from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Optional
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
if TYPE_CHECKING: if TYPE_CHECKING:
import torch
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool from torch.cuda.memory import MemPool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Singleton state
_manager: Optional["GMSClientMemoryManager"] = None @dataclass
_mem_pool: Optional["MemPool"] = None class _TagState:
_tag: str = "weights" manager: "GMSClientMemoryManager"
_callbacks_initialized: bool = False mem_pool: "MemPool | None"
_pluggable_alloc: Optional[Any] = None socket_path: str
device: int
_tag_states: dict[str, _TagState] = {}
_active_tag: ContextVar[str | None] = ContextVar(
"gpu_memory_service_active_tag",
default=None,
)
_callbacks_initialized = False
_pluggable_alloc: Any | None = None
def _gms_malloc(size: int, device: int, stream: int) -> int: def _gms_malloc(size: int, device: int, stream: int) -> int:
"""Route malloc to the singleton weights manager.""" tag = _active_tag.get()
if _manager is None: if tag is None:
raise RuntimeError("No GMS manager initialized") raise RuntimeError("No active GMS allocation tag")
va = _manager.create_mapping(size=int(size), tag=_tag)
logger.debug("[GMS] malloc: va=0x%x size=%d", va, size) state = _tag_states.get(tag)
if state is None:
raise RuntimeError(f"Unknown GMS allocation tag: {tag}")
va = state.manager.create_mapping(size=int(size), tag=tag)
logger.debug("[GMS] malloc(tag=%s): va=0x%x size=%d", tag, va, size)
return va return va
def _gms_free(ptr: int, size: int, device: int, stream: int) -> None: def _gms_free(ptr: int, size: int, device: int, stream: int) -> None:
"""Route free to the singleton weights manager.""" va = int(ptr)
if _manager is None: for tag, state in _tag_states.items():
logger.warning("[GMS] free: no manager, ignoring va=0x%x", ptr) if va not in state.manager.mappings:
continue
logger.debug("[GMS] free(tag=%s): va=0x%x size=%d", tag, va, size)
state.manager.destroy_mapping(va)
return return
if int(ptr) in _manager.mappings: logger.warning("[GMS] free: no manager owns va=0x%x, ignoring", va)
logger.debug("[GMS] free: va=0x%x size=%d", ptr, size)
_manager.destroy_mapping(int(ptr))
else:
logger.warning("[GMS] free: manager does not own va=0x%x, ignoring", ptr)
def _ensure_callbacks_initialized() -> "MemPool": def _ensure_callbacks_initialized() -> None:
"""Initialize C-level callbacks exactly once, return a new MemPool."""
global _callbacks_initialized, _pluggable_alloc global _callbacks_initialized, _pluggable_alloc
from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem
from torch.cuda import CUDAPluggableAllocator from torch.cuda import CUDAPluggableAllocator
from torch.cuda.memory import MemPool
if not _callbacks_initialized: if _callbacks_initialized:
_pluggable_alloc = CUDAPluggableAllocator( return
cumem.__file__, "my_malloc", "my_free"
) _pluggable_alloc = CUDAPluggableAllocator(cumem.__file__, "my_malloc", "my_free")
cumem.init_module(_gms_malloc, _gms_free) cumem.init_module(_gms_malloc, _gms_free)
_callbacks_initialized = True _callbacks_initialized = True
def _create_mem_pool() -> "MemPool":
from torch.cuda.memory import MemPool
assert _pluggable_alloc is not None
return MemPool(allocator=_pluggable_alloc.allocator()) return MemPool(allocator=_pluggable_alloc.allocator())
...@@ -74,66 +91,98 @@ def get_or_create_gms_client_memory_manager( ...@@ -74,66 +91,98 @@ def get_or_create_gms_client_memory_manager(
*, *,
tag: str = "weights", tag: str = "weights",
timeout_ms: Optional[int] = None, timeout_ms: Optional[int] = None,
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]: ) -> "GMSClientMemoryManager":
"""Get existing memory manager, or create a new one.
Args:
socket_path: Unix socket path for the allocation server.
device: CUDA device index.
mode: RW for cold start, RO for import-only, RW_OR_RO for auto.
tag: Allocation tag for RW mode.
timeout_ms: Lock acquisition timeout (None = wait indefinitely).
Returns:
(gms_client_memory_manager, pool) - pool is None for RO mode.
"""
global _manager, _mem_pool, _tag
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
if _manager is not None: state = _tag_states.get(tag)
return _get_existing(mode) if state is not None:
if state.socket_path != socket_path or state.device != device:
raise RuntimeError(
f"GMS allocator tag={tag} was initialized for "
f"{state.socket_path} on device {state.device}, not {socket_path} "
f"on device {device}"
)
manager = state.manager
if not manager.is_connected:
if manager.mappings or manager.is_unmapped or manager.granted_lock_type:
raise RuntimeError(
f"GMS allocator tag={tag} is disconnected but still owns "
"preserved state; recreate the process instead of reusing it"
)
manager._client = None
manager._granted_lock_type = None
_tag_states.pop(tag, None)
state = None
if state is not None:
current = state.manager.granted_lock_type
if mode == RequestedLockType.RW and current != GrantedLockType.RW:
raise RuntimeError(
f"Cannot get RW allocator for tag {tag}: existing is in {current} mode"
)
if mode == RequestedLockType.RO and current != GrantedLockType.RO:
raise RuntimeError(
f"Cannot get RO allocator for tag {tag}: existing is in {current} mode"
)
return state.manager
manager = GMSClientMemoryManager(socket_path, device=device) manager = GMSClientMemoryManager(socket_path, device=device)
manager.connect(mode, timeout_ms=timeout_ms) manager.connect(mode, timeout_ms=timeout_ms)
mem_pool = None
if manager.granted_lock_type == GrantedLockType.RW: if manager.granted_lock_type == GrantedLockType.RW:
pool = _ensure_callbacks_initialized() _ensure_callbacks_initialized()
# Only set globals after mempool succeeds (avoids partial singleton) mem_pool = _create_mem_pool()
_manager = manager
_tag = tag _tag_states[tag] = _TagState(
_mem_pool = pool manager=manager,
logger.info("[GMS] Created RW allocator (device=%d)", device) mem_pool=mem_pool,
return manager, pool socket_path=socket_path,
else: device=device,
_manager = manager )
_tag = tag logger.info(
logger.info("[GMS] Created RO allocator (device=%d)", device) "[GMS] Created %s allocator for tag=%s (device=%d)",
return manager, None manager.granted_lock_type.value,
tag,
device,
def _get_existing( )
mode: RequestedLockType, return manager
) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]:
"""Return existing allocator if mode-compatible."""
assert _manager is not None def get_gms_client_memory_manager(
current = _manager.granted_lock_type tag: str = "weights",
) -> "GMSClientMemoryManager | None":
state = _tag_states.get(tag)
if state is None:
return None
return state.manager
def get_gms_client_memory_managers() -> tuple["GMSClientMemoryManager", ...]:
return tuple(state.manager for state in _tag_states.values())
if mode == RequestedLockType.RW: def evict_gms_client_memory_manager(manager: "GMSClientMemoryManager") -> None:
if current == GrantedLockType.RW: for tag, state in list(_tag_states.items()):
return _manager, _mem_pool if state.manager is manager:
raise RuntimeError(f"Cannot get RW allocator: existing is in {current} mode") _tag_states.pop(tag, None)
return
if mode == RequestedLockType.RO:
if current == GrantedLockType.RO:
return _manager, None
raise RuntimeError(f"Cannot get RO allocator: existing is in {current} mode")
# RW_OR_RO: return whatever exists @contextmanager
effective_pool = _mem_pool if current == GrantedLockType.RW else None def gms_use_mem_pool(tag: str, device: "torch.device | int") -> Iterator[None]:
return _manager, effective_pool import torch
state = _tag_states.get(tag)
if state is None:
raise RuntimeError(f"No GMS allocator initialized for tag={tag}")
if state.mem_pool is None:
raise RuntimeError(f"GMS allocator tag={tag} does not have a mempool")
def get_gms_client_memory_manager() -> Optional["GMSClientMemoryManager"]: token = _active_tag.set(tag)
"""Get the active GMS client memory manager, or None.""" try:
return _manager with torch.cuda.use_mem_pool(state.mem_pool, device=device):
yield
finally:
_active_tag.reset(token)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA driver helpers shared by the GMS client and server."""
from __future__ import annotations
import atexit
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.utils import fail
# Primary contexts retained by cuda_set_current_device(), keyed by device index.
_primary_contexts: dict[int, object] = {}
# True once _release_primary_contexts has been registered with atexit.
_primary_context_release_registered = False
def cuda_check_result(result: cuda.CUresult, name: str) -> None:
    """Report a fatal error through ``fail`` unless ``result`` is a success.

    Args:
        result: Status code returned by a CUDA driver API call.
        name: Name of the operation, used in the error message.
    """
    success = cuda.CUresult.CUDA_SUCCESS
    if result == success:
        return

    lookup_status, raw_text = cuda.cuGetErrorString(result)
    if lookup_status == success and raw_text:
        # The error string may be bytes or an already-decoded object.
        if isinstance(raw_text, bytes):
            description = raw_text.decode()
        else:
            description = str(raw_text)
    else:
        # No usable description; fall back to the enum's own text.
        description = str(result)
    fail("fatal CUDA VMM error in %s: %s", name, description)
def cuda_ensure_initialized() -> None:
    """Call ``cuInit(0)`` and check its status under the name "cuInit"."""
    init_outcome = cuda.cuInit(0)
    (status,) = init_outcome
    cuda_check_result(status, "cuInit")
def cumem_get_allocation_granularity(device: int) -> int:
    """Query the minimum VMM allocation granularity for a device.

    The query describes a pinned, device-located allocation that requests
    a POSIX file-descriptor handle type.

    Args:
        device: CUDA device index.

    Returns:
        Allocation granularity in bytes (typically 2 MiB).
    """
    alloc_prop = cuda.CUmemAllocationProp()
    alloc_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    alloc_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    alloc_prop.location.id = device
    alloc_prop.requestedHandleTypes = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )

    minimum_flag = (
        cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
    )
    status, raw_granularity = cuda.cuMemGetAllocationGranularity(
        alloc_prop, minimum_flag
    )
    cuda_check_result(status, "cuMemGetAllocationGranularity")
    return int(raw_granularity)
def cumem_create_tolerate_oom(size: int, device: int) -> tuple[bool, int]:
    """Create a physical allocation, reporting out-of-memory as a soft failure.

    Args:
        size: Number of bytes passed to ``cuMemCreate``.
        device: CUDA device index the allocation is located on.

    Returns:
        ``(True, handle)`` on success, or ``(False, 0)`` when the driver
        reports ``CUDA_ERROR_OUT_OF_MEMORY``. Any other status is passed
        to ``cuda_check_result``.
    """
    alloc_prop = cuda.CUmemAllocationProp()
    alloc_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    alloc_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    alloc_prop.location.id = device
    alloc_prop.requestedHandleTypes = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )

    status, raw_handle = cuda.cuMemCreate(size, alloc_prop, 0)
    if status == cuda.CUresult.CUDA_SUCCESS:
        return True, int(raw_handle)
    if status != cuda.CUresult.CUDA_ERROR_OUT_OF_MEMORY:
        # Anything other than OOM goes through the shared status check.
        cuda_check_result(status, "cuMemCreate")
    return False, 0
def cumem_export_to_shareable_handle(handle: int) -> int:
    """Export an allocation handle as a POSIX file descriptor.

    Args:
        handle: Allocation handle to export.

    Returns:
        The exported file descriptor as an ``int``.
    """
    fd_handle_type = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )
    status, exported_fd = cuda.cuMemExportToShareableHandle(
        handle, fd_handle_type, 0
    )
    cuda_check_result(status, "cuMemExportToShareableHandle")
    return int(exported_fd)
def align_to_granularity(size: int, granularity: int) -> int:
    """Round ``size`` up to the next multiple of ``granularity``.

    Args:
        size: Size in bytes.
        granularity: Allocation granularity in bytes.

    Returns:
        The smallest multiple of ``granularity`` that is at least ``size``.
    """
    padded = size + granularity - 1
    whole_units = padded // granularity
    return whole_units * granularity
def cumem_import_from_shareable_handle_close_fd(fd: int) -> int:
    """Import an allocation from a POSIX file descriptor, then close the fd.

    Args:
        fd: File descriptor received for a shareable allocation. It is
            closed in a ``finally`` clause around the import.

    Returns:
        The imported allocation handle as an ``int``.
    """
    try:
        status, imported = cuda.cuMemImportFromShareableHandle(
            fd,
            cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
        )
        cuda_check_result(status, "cuMemImportFromShareableHandle")
        handle_value = int(imported)
        return handle_value
    finally:
        os.close(fd)
def cumem_address_reserve(size: int, granularity: int) -> int:
    """Reserve a virtual address range via ``cuMemAddressReserve``.

    Args:
        size: Number of bytes to reserve.
        granularity: Passed as the second argument of the driver call;
            the remaining two arguments are both 0.

    Returns:
        The reserved base virtual address as an ``int``.
    """
    status, base_va = cuda.cuMemAddressReserve(size, granularity, 0, 0)
    cuda_check_result(status, "cuMemAddressReserve")
    return int(base_va)
def cumem_address_free(va: int, size: int) -> None:
    """Release a reserved virtual address range via ``cuMemAddressFree``.

    Args:
        va: Base virtual address of the range.
        size: Size of the range in bytes.
    """
    free_outcome = cuda.cuMemAddressFree(va, size)
    (status,) = free_outcome
    cuda_check_result(status, "cuMemAddressFree")
def cumem_map(va: int, size: int, handle: int) -> None:
    """Map an allocation handle at a virtual address via ``cuMemMap``.

    Args:
        va: Virtual address to map at.
        size: Number of bytes to map.
        handle: Allocation handle to map; the two remaining driver
            arguments are passed as 0.
    """
    map_outcome = cuda.cuMemMap(va, size, 0, handle, 0)
    (status,) = map_outcome
    cuda_check_result(status, "cuMemMap")
def cumem_set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
    """Set device access protection on a mapped range via ``cuMemSetAccess``.

    Args:
        va: Base virtual address of the range.
        size: Size of the range in bytes.
        device: CUDA device index that is granted access.
        access: ``GrantedLockType.RO`` selects read-only protection; any
            other value selects read-write.
    """
    if access == GrantedLockType.RO:
        protection = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
    else:
        protection = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

    descriptor = cuda.CUmemAccessDesc()
    descriptor.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    descriptor.location.id = device
    descriptor.flags = protection

    (status,) = cuda.cuMemSetAccess(va, size, [descriptor], 1)
    cuda_check_result(status, "cuMemSetAccess")
def cumem_unmap(va: int, size: int) -> None:
    """Unmap a virtual address range via ``cuMemUnmap``.

    Args:
        va: Base virtual address of the range.
        size: Size of the range in bytes.
    """
    unmap_outcome = cuda.cuMemUnmap(va, size)
    (status,) = unmap_outcome
    cuda_check_result(status, "cuMemUnmap")
def cumem_release(handle: int) -> None:
    """Release an allocation handle via ``cuMemRelease``.

    Args:
        handle: Allocation handle to release.
    """
    release_outcome = cuda.cuMemRelease(handle)
    (status,) = release_outcome
    cuda_check_result(status, "cuMemRelease")
def cuda_validate_pointer(va: int) -> None:
    """Check that the driver can resolve ``va`` as a device pointer.

    The device-pointer attribute is queried only for its status; the
    returned attribute value is discarded.

    Args:
        va: Virtual address to validate.
    """
    attribute = cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER
    status, _ignored = cuda.cuPointerGetAttribute(attribute, va)
    cuda_check_result(status, "cuPointerGetAttribute")
def cuda_synchronize() -> None:
    """Call ``cuCtxSynchronize`` and check its status."""
    sync_outcome = cuda.cuCtxSynchronize()
    (status,) = sync_outcome
    cuda_check_result(status, "cuCtxSynchronize")
def cuda_set_current_device(device: int) -> None:
    """Make the primary context of ``device`` current on the calling thread.

    The primary context is retained on first use for each device and
    cached in ``_primary_contexts``. The first retain also registers
    ``_release_primary_contexts`` with ``atexit``.

    Args:
        device: CUDA device index.
    """
    global _primary_context_release_registered

    context = _primary_contexts.get(device)
    if context is None:
        retain_status, context = cuda.cuDevicePrimaryCtxRetain(device)
        cuda_check_result(retain_status, "cuDevicePrimaryCtxRetain")
        _primary_contexts[device] = context
        if not _primary_context_release_registered:
            # Flag first, then register, so the hook is installed only once.
            _primary_context_release_registered = True
            atexit.register(_release_primary_contexts)

    (set_status,) = cuda.cuCtxSetCurrent(context)
    cuda_check_result(set_status, "cuCtxSetCurrent")
def _release_primary_contexts() -> None:
    """Release every cached primary context, best effort.

    A device whose release raises or returns a non-success status is left
    in ``_primary_contexts``; only successful releases are removed.
    """
    # Iterate over a snapshot because entries are removed during the loop.
    for device in tuple(_primary_contexts):
        try:
            (status,) = cuda.cuDevicePrimaryCtxRelease(device)
        except Exception:
            # Deliberately best effort: move on to the next device.
            continue
        if status != cuda.CUresult.CUDA_SUCCESS:
            continue
        _primary_contexts.pop(device, None)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""CUDA Virtual Memory Management (VMM) utility functions.
This module provides utility functions for CUDA driver API operations
used by both server (GMSServerMemoryManager) and client (GMSClientMemoryManager).
"""
from cuda.bindings import driver as cuda
def check_cuda_result(result: cuda.CUresult, name: str) -> None:
    """Raise if a CUDA driver API call did not succeed.

    Args:
        result: CUDA driver API return code (``CUresult`` enum).
        name: Operation name used as the prefix of the error message.

    Raises:
        RuntimeError: If ``result`` is not ``CUDA_SUCCESS``.
    """
    success = cuda.CUresult.CUDA_SUCCESS
    if result == success:
        return

    lookup_status, raw_text = cuda.cuGetErrorString(result)
    if lookup_status == success and raw_text:
        # The error string may be bytes or an already-decoded object.
        if isinstance(raw_text, bytes):
            description = raw_text.decode()
        else:
            description = str(raw_text)
    else:
        # No usable description; fall back to the enum's own text.
        description = str(result)
    raise RuntimeError(f"{name}: {description}")
def ensure_cuda_initialized() -> None:
    """Call ``cuInit(0)`` and check its status.

    Raises:
        RuntimeError: If ``cuInit`` fails.
    """
    init_outcome = cuda.cuInit(0)
    (status,) = init_outcome
    check_cuda_result(status, "cuInit")
def get_allocation_granularity(device: int) -> int:
    """Query the minimum VMM allocation granularity for a device.

    The query describes a pinned, device-located allocation that requests
    a POSIX file-descriptor handle type.

    Args:
        device: CUDA device index.

    Returns:
        Allocation granularity in bytes (typically 2 MiB).
    """
    alloc_prop = cuda.CUmemAllocationProp()
    alloc_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    alloc_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    alloc_prop.location.id = device
    alloc_prop.requestedHandleTypes = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )

    minimum_flag = (
        cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
    )
    status, raw_granularity = cuda.cuMemGetAllocationGranularity(
        alloc_prop, minimum_flag
    )
    check_cuda_result(status, "cuMemGetAllocationGranularity")
    return int(raw_granularity)
def align_to_granularity(size: int, granularity: int) -> int:
    """Round ``size`` up to the next multiple of ``granularity``.

    Args:
        size: Size in bytes.
        granularity: Allocation granularity in bytes.

    Returns:
        The smallest multiple of ``granularity`` that is at least ``size``.
    """
    padded = size + granularity - 1
    whole_units = padded // granularity
    return whole_units * granularity
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Message types for GPU Memory Service RPC protocol.""" """Message types for GPU Memory Service RPC protocol."""
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Union from typing import List, Optional, Union
import msgspec import msgspec
...@@ -62,7 +62,6 @@ class GetAllocationStateRequest(msgspec.Struct, tag="get_allocation_state_reques ...@@ -62,7 +62,6 @@ class GetAllocationStateRequest(msgspec.Struct, tag="get_allocation_state_reques
class GetAllocationStateResponse(msgspec.Struct, tag="get_allocation_state_response"): class GetAllocationStateResponse(msgspec.Struct, tag="get_allocation_state_response"):
allocation_count: int allocation_count: int
total_bytes: int
class AllocateRequest(msgspec.Struct, tag="allocate_request"): class AllocateRequest(msgspec.Struct, tag="allocate_request"):
...@@ -74,12 +73,21 @@ class AllocateResponse(msgspec.Struct, tag="allocate_response"): ...@@ -74,12 +73,21 @@ class AllocateResponse(msgspec.Struct, tag="allocate_response"):
allocation_id: str allocation_id: str
size: int size: int
aligned_size: int aligned_size: int
layout_slot: int
class ExportRequest(msgspec.Struct, tag="export_request"): class ExportAllocationRequest(msgspec.Struct, tag="export_allocation_request"):
allocation_id: str allocation_id: str
class ExportAllocationResponse(msgspec.Struct, tag="export_allocation_response"):
allocation_id: str
size: int
aligned_size: int
tag: str
layout_slot: int
class GetAllocationRequest(msgspec.Struct, tag="get_allocation_request"): class GetAllocationRequest(msgspec.Struct, tag="get_allocation_request"):
allocation_id: str allocation_id: str
...@@ -89,6 +97,7 @@ class GetAllocationResponse(msgspec.Struct, tag="get_allocation_response"): ...@@ -89,6 +97,7 @@ class GetAllocationResponse(msgspec.Struct, tag="get_allocation_response"):
size: int size: int
aligned_size: int aligned_size: int
tag: str tag: str
layout_slot: int
class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"): class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
...@@ -96,25 +105,17 @@ class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"): ...@@ -96,25 +105,17 @@ class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"):
class ListAllocationsResponse(msgspec.Struct, tag="list_allocations_response"): class ListAllocationsResponse(msgspec.Struct, tag="list_allocations_response"):
allocations: List[Dict[str, Any]] = [] allocations: List[GetAllocationResponse] = []
class FreeRequest(msgspec.Struct, tag="free_request"): class FreeAllocationRequest(msgspec.Struct, tag="free_allocation_request"):
allocation_id: str allocation_id: str
class FreeResponse(msgspec.Struct, tag="free_response"): class FreeAllocationResponse(msgspec.Struct, tag="free_allocation_response"):
success: bool success: bool
class ClearAllRequest(msgspec.Struct, tag="clear_all_request"):
pass
class ClearAllResponse(msgspec.Struct, tag="clear_all_response"):
cleared_count: int
class ErrorResponse(msgspec.Struct, tag="error_response"): class ErrorResponse(msgspec.Struct, tag="error_response"):
error: str error: str
code: int = 0 code: int = 0
...@@ -166,6 +167,34 @@ class GetStateHashResponse(msgspec.Struct, tag="get_memory_layout_hash_response" ...@@ -166,6 +167,34 @@ class GetStateHashResponse(msgspec.Struct, tag="get_memory_layout_hash_response"
memory_layout_hash: str # Hash of allocations + metadata, empty if not committed memory_layout_hash: str # Hash of allocations + metadata, empty if not committed
class GetRuntimeStateRequest(msgspec.Struct, tag="get_runtime_state_request"):
    """Field-less request message tagged ``get_runtime_state_request``."""
class GetRuntimeStateResponse(msgspec.Struct, tag="get_runtime_state_response"):
    """Runtime-state reply message tagged ``get_runtime_state_response``.

    The first six fields are required; ``allocation_count`` and
    ``memory_layout_hash`` default to ``0`` and ``""``.

    NOTE(review): field meanings are only implied by their names here;
    confirm them against the code that builds this response.
    """

    state: str
    has_rw_session: bool
    ro_session_count: int
    waiting_writers: int
    committed: bool
    is_ready: bool
    allocation_count: int = 0
    memory_layout_hash: str = ""
class GMSRuntimeEvent(msgspec.Struct):
    """One runtime event entry, declared without a message tag.

    ``kind`` is required; ``allocation_count`` defaults to ``0``.
    """

    kind: str
    allocation_count: int = 0
class GetEventHistoryRequest(msgspec.Struct, tag="get_event_history_request"):
    """Field-less request message tagged ``get_event_history_request``."""
class GetEventHistoryResponse(msgspec.Struct, tag="get_event_history_response"):
    """Event-history reply message tagged ``get_event_history_response``.

    ``events`` holds ``GMSRuntimeEvent`` entries and defaults to an empty
    list.
    """

    # NOTE(review): relies on msgspec's handling of an empty-list default
    # (presumably a fresh list per instance) - confirm in the msgspec docs.
    events: List[GMSRuntimeEvent] = []
Message = Union[ Message = Union[
HandshakeRequest, HandshakeRequest,
HandshakeResponse, HandshakeResponse,
...@@ -177,15 +206,14 @@ Message = Union[ ...@@ -177,15 +206,14 @@ Message = Union[
GetAllocationStateResponse, GetAllocationStateResponse,
AllocateRequest, AllocateRequest,
AllocateResponse, AllocateResponse,
ExportRequest, ExportAllocationRequest,
ExportAllocationResponse,
GetAllocationRequest, GetAllocationRequest,
GetAllocationResponse, GetAllocationResponse,
ListAllocationsRequest, ListAllocationsRequest,
ListAllocationsResponse, ListAllocationsResponse,
FreeRequest, FreeAllocationRequest,
FreeResponse, FreeAllocationResponse,
ClearAllRequest,
ClearAllResponse,
ErrorResponse, ErrorResponse,
MetadataPutRequest, MetadataPutRequest,
MetadataPutResponse, MetadataPutResponse,
...@@ -197,6 +225,10 @@ Message = Union[ ...@@ -197,6 +225,10 @@ Message = Union[
MetadataListResponse, MetadataListResponse,
GetStateHashRequest, GetStateHashRequest,
GetStateHashResponse, GetStateHashResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
GetEventHistoryRequest,
GetEventHistoryResponse,
] ]
_encoder = msgspec.msgpack.Encoder() _encoder = msgspec.msgpack.Encoder()
......
...@@ -96,7 +96,11 @@ async def recv_message( ...@@ -96,7 +96,11 @@ async def recv_message(
raw_msg, fds, _flags, _addr = await loop.run_in_executor( raw_msg, fds, _flags, _addr = await loop.run_in_executor(
None, lambda: socket.recv_fds(raw_sock, 65536, 1) None, lambda: socket.recv_fds(raw_sock, 65536, 1)
) )
for extra_fd in fds[1:]:
os.close(extra_fd)
if not raw_msg: if not raw_msg:
if fds:
os.close(fds[0])
raise ConnectionResetError("Connection closed") raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg) recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1 fd = fds[0] if fds else -1
...@@ -107,21 +111,25 @@ async def recv_message( ...@@ -107,21 +111,25 @@ async def recv_message(
recv_buffer.extend(chunk) recv_buffer.extend(chunk)
# Try to extract message, read more if needed # Try to extract message, read more if needed
msg, remaining, bytes_needed = _try_extract_message(recv_buffer) try:
while msg is None and bytes_needed > 0: msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
if raw_sock is not None: while msg is None and bytes_needed > 0:
# Continue reading from raw socket to avoid buffer inconsistency if raw_sock is not None:
chunk = await loop.run_in_executor( # Continue reading from raw socket to avoid buffer inconsistency
None, lambda n=bytes_needed: raw_sock.recv(n) chunk = await loop.run_in_executor(
) None, lambda n=bytes_needed: raw_sock.recv(n)
else: )
chunk = await reader.read(bytes_needed) else:
if not chunk: chunk = await reader.read(bytes_needed)
raise ConnectionResetError("Connection closed") if not chunk:
remaining.extend(chunk) raise ConnectionResetError("Connection closed")
msg, remaining, bytes_needed = _try_extract_message(remaining) remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining return msg, fd, remaining
except Exception:
if fd >= 0:
os.close(fd)
raise
# ==================== Sync (for client) ==================== # ==================== Sync (for client) ====================
...@@ -153,18 +161,26 @@ def recv_message_sync( ...@@ -153,18 +161,26 @@ def recv_message_sync(
# Receive more data (with potential FD) # Receive more data (with potential FD)
raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1) raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1)
for extra_fd in fds[1:]:
os.close(extra_fd)
if not raw_msg: if not raw_msg:
if fds:
os.close(fds[0])
raise ConnectionResetError("Connection closed") raise ConnectionResetError("Connection closed")
recv_buffer.extend(raw_msg) recv_buffer.extend(raw_msg)
fd = fds[0] if fds else -1 fd = fds[0] if fds else -1
# Try to extract message, read more if needed # Try to extract message, read more if needed
msg, remaining, bytes_needed = _try_extract_message(recv_buffer) try:
while msg is None and bytes_needed > 0: msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
chunk = sock.recv(bytes_needed) while msg is None and bytes_needed > 0:
if not chunk: chunk = sock.recv(bytes_needed)
raise ConnectionResetError("Connection closed") if not chunk:
remaining.extend(chunk) raise ConnectionResetError("Connection closed")
msg, remaining, bytes_needed = _try_extract_message(remaining) remaining.extend(chunk)
msg, remaining, bytes_needed = _try_extract_message(remaining)
return msg, fd, remaining return msg, fd, remaining
except Exception:
if fd >= 0:
os.close(fd)
raise
...@@ -8,10 +8,9 @@ from enum import Enum, auto ...@@ -8,10 +8,9 @@ from enum import Enum, auto
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
AllocateRequest, AllocateRequest,
ClearAllRequest,
CommitRequest, CommitRequest,
ExportRequest, ExportAllocationRequest,
FreeRequest, FreeAllocationRequest,
GetAllocationRequest, GetAllocationRequest,
GetAllocationStateRequest, GetAllocationStateRequest,
GetLockStateRequest, GetLockStateRequest,
...@@ -89,8 +88,7 @@ def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState: ...@@ -89,8 +88,7 @@ def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState:
RW_REQUIRED: frozenset[type] = frozenset( RW_REQUIRED: frozenset[type] = frozenset(
{ {
AllocateRequest, AllocateRequest,
FreeRequest, FreeAllocationRequest,
ClearAllRequest,
MetadataPutRequest, MetadataPutRequest,
MetadataDeleteRequest, MetadataDeleteRequest,
CommitRequest, CommitRequest,
...@@ -99,7 +97,7 @@ RW_REQUIRED: frozenset[type] = frozenset( ...@@ -99,7 +97,7 @@ RW_REQUIRED: frozenset[type] = frozenset(
RO_ALLOWED: frozenset[type] = frozenset( RO_ALLOWED: frozenset[type] = frozenset(
{ {
ExportRequest, ExportAllocationRequest,
GetAllocationRequest, GetAllocationRequest,
ListAllocationsRequest, ListAllocationsRequest,
MetadataGetRequest, MetadataGetRequest,
......
...@@ -3,36 +3,39 @@ ...@@ -3,36 +3,39 @@
"""Shared utilities for GPU Memory Service.""" """Shared utilities for GPU Memory Service."""
import logging
import os import os
import tempfile import tempfile
import uuid from typing import NoReturn
from cuda.bindings import driver as cuda logger = logging.getLogger(__name__)
from gpu_memory_service.common.cuda_vmm_utils import (
check_cuda_result,
ensure_cuda_initialized,
)
def get_socket_path(device: int) -> str: def fail(message: str, *args, exc_info=None) -> NoReturn:
"""Get GMS socket path for the given CUDA device. logger.critical(message, *args, exc_info=exc_info)
logging.shutdown()
os._exit(1)
The socket path is based on GPU UUID resolved by CUDA.
CUDA_VISIBLE_DEVICES remapping is handled by CUDA device enumeration. def get_socket_path(device: int, tag: str = "weights") -> str:
"""Get GMS socket path for the given CUDA device and tag.
The socket path is based on GPU UUID, making it stable across different
CUDA_VISIBLE_DEVICES configurations.
Args: Args:
device: CUDA device index. device: CUDA device index.
Returns: Returns:
Socket path (e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc.sock"). Socket path
(e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock").
""" """
ensure_cuda_initialized() import pynvml
result, cu_device = cuda.cuDeviceGet(device) pynvml.nvmlInit()
check_cuda_result(result, "cuDeviceGet") try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
result, cu_uuid = cuda.cuDeviceGetUuid(cu_device) uuid = pynvml.nvmlDeviceGetUUID(handle)
check_cuda_result(result, "cuDeviceGetUuid") finally:
pynvml.nvmlShutdown()
gpu_uuid = f"GPU-{uuid.UUID(bytes=bytes(cu_uuid.bytes))}" return os.path.join(tempfile.gettempdir(), f"gms_{uuid}_{tag}.sock")
return os.path.join(tempfile.gettempdir(), f"gms_{gpu_uuid}.sock")
...@@ -8,7 +8,7 @@ from __future__ import annotations ...@@ -8,7 +8,7 @@ from __future__ import annotations
import logging import logging
import torch import torch
from gpu_memory_service import get_gms_client_memory_manager from gpu_memory_service.client.torch.allocator import get_gms_client_memory_managers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -32,11 +32,13 @@ def patch_empty_cache() -> None: ...@@ -32,11 +32,13 @@ def patch_empty_cache() -> None:
_original_empty_cache = torch.cuda.empty_cache _original_empty_cache = torch.cuda.empty_cache
def safe_empty_cache() -> None: def safe_empty_cache() -> None:
manager = get_gms_client_memory_manager() mapping_count = sum(
if manager is not None and len(manager.mappings) > 0: len(manager.mappings) for manager in get_gms_client_memory_managers()
)
if mapping_count > 0:
logger.debug( logger.debug(
"[GMS] Skipping torch.cuda.empty_cache() - %d VMM allocations active", "[GMS] Skipping torch.cuda.empty_cache() - %d VMM allocations active",
len(manager.mappings), mapping_count,
) )
return return
_original_empty_cache() _original_empty_cache()
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import replace
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
...@@ -29,6 +30,20 @@ def get_gms_lock_mode(extra_config: dict): ...@@ -29,6 +30,20 @@ def get_gms_lock_mode(extra_config: dict):
return RequestedLockType.RW_OR_RO return RequestedLockType.RW_OR_RO
def strip_gms_model_loader_config(load_config, load_format: str):
    """Copy a loader config with GMS-only keys removed for backend loaders.

    Args:
        load_config: Dataclass instance to copy. Its
            ``model_loader_extra_config`` attribute is treated as an
            empty dict when missing or falsy.
        load_format: Value for the copy's ``load_format`` field.

    Returns:
        A copy made with ``dataclasses.replace`` whose
        ``model_loader_extra_config`` omits every key starting with
        ``"gms_"``. The input config is not modified.
    """
    original_extra = getattr(load_config, "model_loader_extra_config", {}) or {}

    backend_extra = {}
    for key, value in original_extra.items():
        if key.startswith("gms_"):
            continue
        backend_extra[key] = value

    return replace(
        load_config,
        load_format=load_format,
        model_loader_extra_config=backend_extra,
    )
def setup_meta_tensor_workaround() -> None: def setup_meta_tensor_workaround() -> None:
"""Enable workaround for meta tensor operations like torch.nonzero().""" """Enable workaround for meta tensor operations like torch.nonzero()."""
try: try:
...@@ -42,9 +57,9 @@ def setup_meta_tensor_workaround() -> None: ...@@ -42,9 +57,9 @@ def setup_meta_tensor_workaround() -> None:
def finalize_gms_write( def finalize_gms_write(
allocator: "GMSClientMemoryManager", model: torch.nn.Module allocator: "GMSClientMemoryManager", model: torch.nn.Module
) -> int: ) -> int:
"""Finalize GMS write mode: register tensors, commit, switch to read. """Finalize GMS write mode: register tensors, commit, reconnect in read mode.
Flow: register tensors -> sync -> commit (server-only) -> disconnect -> connect(RO) Flow: register tensors -> sync -> unmap + commit -> connect(RO) -> remap
Args: Args:
allocator: The GMS client memory manager in write mode. allocator: The GMS client memory manager in write mode.
...@@ -52,9 +67,6 @@ def finalize_gms_write( ...@@ -52,9 +67,6 @@ def finalize_gms_write(
Returns: Returns:
Total bytes committed. Total bytes committed.
Raises:
RuntimeError: If commit fails.
""" """
from gpu_memory_service.client.torch.module import register_module_tensors from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.types import RequestedLockType from gpu_memory_service.common.types import RequestedLockType
...@@ -65,12 +77,10 @@ def finalize_gms_write( ...@@ -65,12 +77,10 @@ def finalize_gms_write(
# Synchronize before commit — caller's writes must be visible # Synchronize before commit — caller's writes must be visible
torch.cuda.synchronize() torch.cuda.synchronize()
if not allocator.commit(): allocator.commit()
raise RuntimeError("GMS commit failed")
# commit() closed the RW socket; acquire RO for inference
allocator.disconnect() # no-op if commit already cleared _client, but safe
allocator.connect(RequestedLockType.RO) allocator.connect(RequestedLockType.RO)
allocator.remap_all_vas()
logger.info( logger.info(
"[GMS] Committed %.2f GiB, switched to read mode with %d mappings", "[GMS] Committed %.2f GiB, switched to read mode with %d mappings",
......
...@@ -3,13 +3,10 @@ ...@@ -3,13 +3,10 @@
"""Hybrid torch_memory_saver implementation for GPU Memory Service. """Hybrid torch_memory_saver implementation for GPU Memory Service.
This module provides a hybrid implementation that combines: This module uses:
1. GPU Memory Service allocator for "weights" tag (VA-stable unmap/remap, shared) 1. GPU Memory Service for "weights" (shared RO/RW publish flow)
2. Torch mempool mode for other tags like "kv_cache" (CPU backup, per-instance) 2. GPU Memory Service for "kv_cache" (RW-only failover flow)
3. torch_memory_saver for any remaining tags
The impl uses RW_OR_RO mode to connect to GMS:
- First process gets RW lock and loads weights from disk
- Subsequent processes get RO lock and import weights from metadata
""" """
from __future__ import annotations from __future__ import annotations
...@@ -19,10 +16,13 @@ from contextlib import contextmanager ...@@ -19,10 +16,13 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
if TYPE_CHECKING: if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -39,56 +39,54 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]: ...@@ -39,56 +39,54 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
class GMSMemorySaverImpl: class GMSMemorySaverImpl:
"""Hybrid implementation: GMS for weights, torch mempool for KV cache. """Hybrid implementation: GMS for weights and KV cache."""
Routes operations based on tag:
- "weights" or "model_weights": Handled by GMS allocator (VA-stable)
- Other tags (e.g., "kv_cache"): Delegated to torch mempool mode
"""
def __init__( def __init__(
self, self,
torch_impl: "_TorchMemorySaverImpl", torch_impl: "_TorchMemorySaverImpl",
socket_path: str,
device_index: int, device_index: int,
mode=None, mode=None,
): ):
self._torch_impl = torch_impl self._torch_impl = torch_impl
self._socket_path = socket_path
self._device_index = device_index self._device_index = device_index
self._requested_mode = mode self._requested_mode = mode
self._disabled = False self._disabled = False
self._imported_weights_bytes: int = 0 self._imported_weights_bytes: int = 0
self._allocator: Optional["GMSClientMemoryManager"] self._weights_allocator: Optional["GMSClientMemoryManager"]
self._mem_pool: Optional["MemPool"] self._kv_cache_allocator: "GMSClientMemoryManager"
self._mode: str self._mode: str
self._allocator, self._mem_pool, self._mode = self._init_allocator() (
self._weights_allocator,
self._kv_cache_allocator,
self._mode,
) = self._init_allocators()
logger.info( logger.info(
"[GMS] Initialized: weights=%s mode (device=%d, socket=%s)", "[GMS] Initialized weights=%s mode, kv_cache=RW (device=%d)",
self._mode.upper(), self._mode.upper(),
device_index, device_index,
socket_path,
) )
def _init_allocator( def _init_allocators(
self, self,
) -> tuple[Optional["GMSClientMemoryManager"], Optional["MemPool"], str]: ) -> tuple[Optional["GMSClientMemoryManager"], "GMSClientMemoryManager", str,]:
"""Create allocator with mode from config (default: RW_OR_RO).""" """Create allocator with mode from config (default: RW_OR_RO)."""
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
mode = self._requested_mode or RequestedLockType.RW_OR_RO mode = self._requested_mode or RequestedLockType.RW_OR_RO
allocator, mem_pool = get_or_create_gms_client_memory_manager( weights_allocator = get_or_create_gms_client_memory_manager(
self._socket_path, get_socket_path(self._device_index, "weights"),
self._device_index, self._device_index,
mode=mode, mode=mode,
tag="weights", tag="weights",
) )
granted_mode = allocator.granted_lock_type kv_cache_allocator = get_or_create_gms_client_memory_manager(
get_socket_path(self._device_index, "kv_cache"),
self._device_index,
mode=RequestedLockType.RW,
tag="kv_cache",
)
granted_mode = weights_allocator.granted_lock_type
if granted_mode == GrantedLockType.RW: if granted_mode == GrantedLockType.RW:
allocator.clear_all_handles()
actual_mode = "write" actual_mode = "write"
else: else:
actual_mode = "read" actual_mode = "read"
...@@ -97,11 +95,7 @@ class GMSMemorySaverImpl: ...@@ -97,11 +95,7 @@ class GMSMemorySaverImpl:
actual_mode.upper(), actual_mode.upper(),
self._device_index, self._device_index,
) )
return ( return weights_allocator, kv_cache_allocator, actual_mode
allocator,
mem_pool if granted_mode == GrantedLockType.RW else None,
actual_mode,
)
def _is_weights_tag(self, tag: Optional[str]) -> bool: def _is_weights_tag(self, tag: Optional[str]) -> bool:
return tag in ("weights", "model_weights") return tag in ("weights", "model_weights")
...@@ -110,25 +104,28 @@ class GMSMemorySaverImpl: ...@@ -110,25 +104,28 @@ class GMSMemorySaverImpl:
return self._mode return self._mode
def get_allocator(self) -> Optional["GMSClientMemoryManager"]: def get_allocator(self) -> Optional["GMSClientMemoryManager"]:
return self._allocator return self._weights_allocator
@contextmanager @contextmanager
def region(self, tag: str, enable_cpu_backup: bool): def region(self, tag: str, enable_cpu_backup: bool):
"""Mark allocation region with tag.""" """Mark allocation region with tag."""
if not self._is_weights_tag(tag): if self._is_weights_tag(tag):
with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup): if self._mode == "read":
yield yield
return return
if self._mode == "read": target_device = torch.device("cuda", self._device_index)
yield with gms_use_mem_pool("weights", target_device):
yield
return return
if self._mem_pool is None: if tag == "kv_cache":
raise RuntimeError("GMS mempool is None in WRITE mode") target_device = torch.device("cuda", self._device_index)
with gms_use_mem_pool("kv_cache", target_device):
yield
return
target_device = torch.device("cuda", self._device_index) with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
with torch.cuda.use_mem_pool(self._mem_pool, device=target_device):
yield yield
def pause(self, tag: Optional[str] = None) -> None: def pause(self, tag: Optional[str] = None) -> None:
...@@ -136,7 +133,9 @@ class GMSMemorySaverImpl: ...@@ -136,7 +133,9 @@ class GMSMemorySaverImpl:
return return
if tag is None or self._is_weights_tag(tag): if tag is None or self._is_weights_tag(tag):
self._pause_weights() self._pause_weights()
if tag is None or not self._is_weights_tag(tag): if tag is None or tag == "kv_cache":
self._pause_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.pause(tag=tag) self._torch_impl.pause(tag=tag)
def resume(self, tag: Optional[str] = None) -> None: def resume(self, tag: Optional[str] = None) -> None:
...@@ -144,39 +143,56 @@ class GMSMemorySaverImpl: ...@@ -144,39 +143,56 @@ class GMSMemorySaverImpl:
return return
if tag is None or self._is_weights_tag(tag): if tag is None or self._is_weights_tag(tag):
self._resume_weights() self._resume_weights()
if tag is None or not self._is_weights_tag(tag): if tag is None or tag == "kv_cache":
self._resume_kv_cache()
if tag is None or (not self._is_weights_tag(tag) and tag != "kv_cache"):
self._torch_impl.resume(tag=tag) self._torch_impl.resume(tag=tag)
def _pause_weights(self) -> None: def _pause_weights(self) -> None:
if self._allocator is None: if self._weights_allocator is None:
return return
if self._allocator.is_unmapped: if self._weights_allocator.is_unmapped:
return return
logger.info("[GMS] Unmapping weights (VA-stable)") logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap_all_vas() self._weights_allocator.unmap_all_vas()
self._allocator.disconnect() self._weights_allocator.abort()
def _resume_weights(self) -> None: def _resume_weights(self) -> None:
if self._allocator is None: if self._weights_allocator is None:
return return
if not self._allocator.is_unmapped: if not self._weights_allocator.is_unmapped:
return return
logger.info("[GMS] Remapping weights (VA-stable)") logger.info("[GMS] Remapping weights (VA-stable)")
from gpu_memory_service.common.types import RequestedLockType self._weights_allocator.connect(RequestedLockType.RO)
self._weights_allocator.remap_all_vas()
self._allocator.connect(RequestedLockType.RO) def _pause_kv_cache(self) -> None:
self._allocator.remap_all_vas() if self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Unmapping KV cache")
self._kv_cache_allocator.unmap_all_vas()
self._kv_cache_allocator.abort()
def _resume_kv_cache(self) -> None:
if not self._kv_cache_allocator.is_unmapped:
return
logger.info("[GMS] Remapping KV cache")
self._kv_cache_allocator.connect(RequestedLockType.RW)
self._kv_cache_allocator.reallocate_all_handles(tag="kv_cache")
self._kv_cache_allocator.remap_all_vas()
def finalize_write_mode(self, model: torch.nn.Module) -> None: def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read.""" """Finalize write mode: register tensors, commit, and switch to read."""
if self._mode != "write": if self._mode != "write":
return return
if self._allocator is None: if self._weights_allocator is None:
raise RuntimeError("Allocator is None in WRITE mode") raise RuntimeError("Allocator is None in WRITE mode")
from gpu_memory_service.integrations.common.utils import finalize_gms_write from gpu_memory_service.integrations.common.utils import finalize_gms_write
self._imported_weights_bytes = finalize_gms_write(self._allocator, model) self._imported_weights_bytes = finalize_gms_write(
self._weights_allocator, model
)
self._mode = "read" self._mode = "read"
def set_imported_weights_bytes(self, bytes_count: int) -> None: def set_imported_weights_bytes(self, bytes_count: int) -> None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment