Unverified Commit f3b181a9 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(gms): operator-managed GMS checkpoint/restore support (#8153)

parent 091cdb51
......@@ -32,16 +32,15 @@ def patch_empty_cache() -> None:
_original_empty_cache = torch.cuda.empty_cache
def safe_empty_cache() -> None:
active_mapping_count = sum(
1
# Allow empty_cache when all managers are unmapped (sleep/checkpoint)
# or when there are no active VMM mappings with live handles.
has_live_mappings = any(
any(m.handle != 0 for m in manager.mappings.values())
for manager in get_gms_client_memory_managers()
for mapping in manager.mappings.values()
if mapping.handle != 0
)
if active_mapping_count:
logger.warning(
"[GMS] Skipping torch.cuda.empty_cache() - %d active GMS mappings",
active_mapping_count,
if has_live_mappings:
logger.debug(
"[GMS] Skipping torch.cuda.empty_cache() - live VMM mappings active",
)
return
_original_empty_cache()
......
......@@ -36,6 +36,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "gpu", "memory", "dynamo"]
[project.scripts]
gpu-memory-service = "gpu_memory_service.cli.runner:main"
gms-storage-client = "gpu_memory_service.cli.storage_runner:main"
[project.optional-dependencies]
test = [
......
......@@ -98,6 +98,12 @@ setup(
package_data={
"gpu_memory_service.client.torch.extensions": ["*.cpp"],
},
entry_points={
"console_scripts": [
"gpu-memory-service=gpu_memory_service.cli.runner:main",
"gms-storage-client=gpu_memory_service.cli.storage_runner:main",
]
},
ext_modules=_create_ext_modules(),
cmdclass={"build_ext": BuildExtension},
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import base64
import errno
import json
import os
import queue
import threading
from collections import defaultdict
from concurrent.futures import CancelledError, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
if TYPE_CHECKING:
import torch
from gpu_memory_service.snapshot.model import AllocationEntry, SaveManifest
class ShardWriter:
    """Sequentially pack allocation bytes into large binary shard files.

    Single-threaded streaming writer.  The parallel save path in
    GMSStorageClient._write_shards plans its layout with plan_shard_layout
    and writes every shard file concurrently, so it does not go through this
    class; ShardWriter remains available as a public utility for callers
    that only need a simple sequential writer.
    """

    def __init__(self, shards_dir: str, shard_size_bytes: int = 4 * 1024**3) -> None:
        """Create *shards_dir* if needed; no shard file is opened until the first write."""
        os.makedirs(shards_dir, exist_ok=True)
        self._shards_dir = shards_dir
        self._shard_size = shard_size_bytes
        # -1 means "no shard opened yet"; the first roll bumps this to 0.
        self._shard_idx = -1
        self._current_offset = 0
        self._current_file: Optional[Any] = None
        self._current_rel_path: str = ""

    def _roll_shard(self) -> None:
        """Close the open shard (if any) and start the next numbered one."""
        previous = self._current_file
        if previous is not None:
            previous.close()
        self._shard_idx += 1
        filename = f"shard_{self._shard_idx:04d}.bin"
        self._current_file = open(os.path.join(self._shards_dir, filename), "wb")
        # The relative path always uses the literal "shards" directory name.
        self._current_rel_path = os.path.join("shards", filename)
        self._current_offset = 0

    def write(self, tensor: torch.Tensor) -> Tuple[str, int]:
        """Append *tensor*'s bytes and return ``(relative shard path, byte offset)``.

        A new shard is started when nothing is open yet, or when the current
        shard already holds data and appending would exceed the shard size.
        An item larger than the shard size therefore gets a shard of its own.
        """
        host = tensor
        if hasattr(tensor, "is_cuda") and tensor.is_cuda:
            host = tensor.cpu()
        if hasattr(host, "is_contiguous") and not host.is_contiguous():
            host = host.contiguous()
        data = host.numpy()
        nbytes = data.nbytes

        needs_new_shard = self._current_file is None
        if not needs_new_shard:
            needs_new_shard = (
                self._current_offset > 0
                and self._current_offset + nbytes > self._shard_size
            )
        if needs_new_shard:
            self._roll_shard()

        start = self._current_offset
        data.tofile(self._current_file)
        self._current_offset = start + nbytes
        return self._current_rel_path, start

    def close(self) -> None:
        """Close the open shard file, if any.  Safe to call repeatedly."""
        if self._current_file is None:
            return
        self._current_file.close()
        self._current_file = None

    def __enter__(self) -> "ShardWriter":
        return self

    def __exit__(self, *_: Any) -> None:
        self.close()
def read_shard_sequential(
    abs_path: str,
    sorted_entries: List[AllocationEntry],
    device: int,
    *,
    pin_memory: bool = False,
    os_module=os,
    np_module=None,
    torch_module=None,
    logger=None,
) -> Dict[str, torch.Tensor]:
    """Read one shard file front-to-back without seeking.

    Args:
        abs_path: Path of the shard file.  A path ending in ``.pt`` is
            treated as a legacy single-tensor file and loaded with
            ``torch_module.load``.
        sorted_entries: Entries stored in this shard, in file order.  They
            are sliced back-to-back starting at byte 0, so the first entry
            must have ``tensor_offset == 0``.
        device: Target CUDA device index; a negative value keeps tensors on
            the CPU.
        pin_memory: Accepted for interface compatibility; this function
            never allocates pinned memory.
        os_module: Module providing ``open``/``readv``/``close`` and,
            optionally, ``O_DIRECT`` (injectable for tests).
        np_module: The numpy module (required).
        torch_module: The torch module (required).
        logger: Optional logger for debug messages about O_DIRECT fallback.

    Returns:
        Mapping of ``allocation_id`` to a uint8 tensor of ``aligned_size``
        bytes.

    Raises:
        RuntimeError: If numpy/torch are missing, a legacy ``.pt`` file has
            more than one entry, the first entry does not start at offset 0,
            or the file ends before all bytes were read.
        OSError: If an O_DIRECT read fails after the file was opened with an
            errno other than EINVAL/EOPNOTSUPP.
    """
    if np_module is None or torch_module is None:
        raise RuntimeError("numpy and torch modules are required to read shards")

    result: Dict[str, torch.Tensor] = {}
    device_str = f"cuda:{device}" if device >= 0 else "cpu"

    if abs_path.endswith(".pt"):
        if len(sorted_entries) != 1:
            raise RuntimeError(
                f"Expected exactly 1 entry for legacy .pt file, got "
                f"{len(sorted_entries)}: {abs_path}"
            )
        entry = sorted_entries[0]
        result[entry.allocation_id] = torch_module.load(
            abs_path,
            weights_only=True,
            map_location=device_str,
        )
        return result

    # Both read paths below slice entries back-to-back from byte 0 of the
    # file.  Validate that once, up front, so the O_DIRECT path cannot
    # silently hand back the wrong bytes for a shard that starts elsewhere.
    if sorted_entries and sorted_entries[0].tensor_offset != 0:
        raise RuntimeError(
            f"Sequential shard read requires entries starting at offset 0, "
            f"got {sorted_entries[0].tensor_offset} in {abs_path}"
        )

    odirect_flag = getattr(os_module, "O_DIRECT", None)
    if odirect_flag is not None:
        fd: Optional[int] = None
        done = 0
        total_size = sum(entry.aligned_size for entry in sorted_entries)
        try:
            # Avoid torch.empty(pin_memory=True): cudaHostAlloc is ~1-3 s/GiB
            # and dominates wall time. Plain numpy gives good throughput since
            # PCIe H2D bandwidth far exceeds network disk bandwidth.
            arr = np_module.empty(total_size, dtype=np_module.uint8)
            fd = os_module.open(abs_path, os_module.O_RDONLY | odirect_flag)
            try:
                mv = memoryview(arr)
                try:
                    while done < total_size:
                        read = os_module.readv(fd, [mv[done:]])
                        if read == 0:
                            raise RuntimeError(
                                f"Unexpected EOF in O_DIRECT read from {abs_path}: "
                                f"got {done} of {total_size} bytes"
                            )
                        done += read
                finally:
                    mv.release()
            finally:
                os_module.close(fd)

            offset = 0
            for entry in sorted_entries:
                size = entry.aligned_size
                tensor = torch_module.from_numpy(arr[offset : offset + size])
                if device >= 0:
                    tensor = tensor.to(device_str)
                result[entry.allocation_id] = tensor
                offset += size
            return result
        except OSError as exc:
            fallback_errnos = {errno.EINVAL, errno.EOPNOTSUPP}
            # fd is None means open() itself failed: always retry buffered.
            # Once the file was open, only "O_DIRECT not usable here" errnos
            # justify a retry; anything else is a real I/O error.
            if fd is not None and exc.errno not in fallback_errnos:
                raise
            result.clear()
            if logger is not None:
                if fd is None:
                    logger.debug(
                        "O_DIRECT unsupported on %s (errno %s); using buffered reads",
                        abs_path,
                        exc.errno,
                    )
                else:
                    logger.debug(
                        "O_DIRECT read on %s hit EINVAL after %d/%d bytes; using buffered reads",
                        abs_path,
                        done,
                        total_size,
                    )

    # Buffered path: used when O_DIRECT is unavailable or was rejected above.
    with open(abs_path, "rb") as handle:
        for entry in sorted_entries:
            raw = handle.read(entry.aligned_size)
            if len(raw) != entry.aligned_size:
                raise RuntimeError(
                    f"Short read from {abs_path} at offset {entry.tensor_offset}: "
                    f"expected {entry.aligned_size} bytes, got {len(raw)}"
                )
            # frombuffer() over bytes is read-only; copy so the tensor owns
            # writable memory.
            arr = np_module.frombuffer(raw, dtype=np_module.uint8).copy()
            tensor = torch_module.from_numpy(arr)
            if device >= 0:
                tensor = tensor.to(device_str)
            result[entry.allocation_id] = tensor
    return result
def decode_metadata(raw_meta: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Convert JSON-loaded metadata records into their in-memory form.

    For every key, ``offset_bytes`` is coerced to ``int`` and the
    base64-encoded ``value`` is decoded to ``bytes``; ``allocation_id`` is
    passed through untouched.
    """
    decoded: Dict[str, Dict[str, Any]] = {}
    for key, record in raw_meta.items():
        allocation_id = record["allocation_id"]
        offset_bytes = int(record["offset_bytes"])
        value = base64.b64decode(record["value"])
        decoded[key] = {
            "allocation_id": allocation_id,
            "offset_bytes": offset_bytes,
            "value": value,
        }
    return decoded
def group_entries_by_shard(
    allocations: List[AllocationEntry],
) -> Dict[str, List[AllocationEntry]]:
    """Bucket entries by ``tensor_file``, each bucket sorted by ``tensor_offset``.

    Buckets appear in the order their shard file is first encountered, and
    the returned mapping is a plain ``dict``.
    """
    by_file: Dict[str, List[AllocationEntry]] = {}
    for item in allocations:
        by_file.setdefault(item.tensor_file, []).append(item)
    for bucket in by_file.values():
        bucket.sort(key=lambda item: item.tensor_offset)
    return by_file
def plan_shard_layout(
    allocations_info: List[Dict[str, Any]],
    shard_size_bytes: int,
) -> List[Tuple[int, int]]:
    """Assign every allocation a ``(shard index, byte offset)`` slot.

    Allocations are packed in input order.  A new shard begins for the very
    first allocation, and afterwards whenever the current shard is non-empty
    and adding the next ``aligned_size`` would exceed *shard_size_bytes*.  An
    allocation larger than the shard size therefore occupies a shard alone.
    """
    placements: List[Tuple[int, int]] = []
    shard_idx = -1  # -1 until the first allocation opens shard 0
    cursor = 0
    for alloc in allocations_info:
        size = int(alloc["aligned_size"])
        is_first = shard_idx < 0
        overflows = cursor > 0 and cursor + size > shard_size_bytes
        if is_first or overflows:
            shard_idx += 1
            cursor = 0
        placements.append((shard_idx, cursor))
        cursor += size
    return placements
def _put_entry(
work_q: queue.Queue[Optional[Tuple[AllocationEntry, "torch.Tensor"]]],
entry: AllocationEntry,
tensor: "torch.Tensor",
cancel_event: Optional[threading.Event],
abs_path: str,
) -> None:
"""Put one entry into the work queue, respecting cancellation."""
while True:
if cancel_event is not None and cancel_event.is_set():
raise CancelledError(f"shard read cancelled: {abs_path}")
try:
work_q.put((entry, tensor), timeout=0.1)
return
except queue.Full:
pass
# 64 MiB chunks for parallel preadv — gives high effective iodepth on NFS
# while keeping each syscall large enough to amortize overhead.
_CHUNK_SIZE = 64 * 1024 * 1024
# How many preadv calls to keep in-flight per shard. On Vast NFS each
# outstanding preadv becomes a separate NFS READ RPC, so higher iodepth
# means more network-level parallelism from a single file descriptor.
_IO_DEPTH = 16
def _preadv_chunk(
fd: int,
buf: memoryview,
file_offset: int,
size: int,
os_module,
) -> None:
"""Read exactly *size* bytes from *fd* at *file_offset* into *buf*."""
done = 0
while done < size:
n = os_module.preadv(fd, [buf[done:size]], file_offset + done)
if n == 0:
raise RuntimeError(
f"Unexpected EOF in preadv at offset {file_offset + done}"
)
done += n
def read_shard_streaming_to_queue(
    abs_path: str,
    sorted_entries: List[AllocationEntry],
    work_q: queue.Queue[Optional[Tuple[AllocationEntry, "torch.Tensor"]]],
    *,
    pin_memory: bool,
    cancel_event: Optional[threading.Event] = None,
    os_module=os,
    np_module=None,
    torch_module=None,
    logger=None,
) -> int:
    """Read a shard via parallel O_DIRECT preadv calls, streaming entries
    to *work_q* as they become readable.

    Multiple chunks are read concurrently from different file offsets to
    achieve high effective I/O depth on network filesystems (e.g. Vast NFS)
    where single-threaded synchronous reads severely under-utilize bandwidth.

    Args:
        abs_path: Path of the shard file.
        sorted_entries: Entries stored in the shard; each is sliced out of
            the shard buffer at its ``tensor_offset``.
        work_q: Queue receiving ``(entry, tensor)`` pairs.
        pin_memory: Accepted for interface compatibility; the shard buffer
            is always a plain numpy array (see the comment below).
        cancel_event: When set, the read stops with ``CancelledError``.
        os_module: Module providing ``open``/``preadv``/``close`` and
            ``O_DIRECT`` (injectable for tests).
        np_module: The numpy module (required).
        torch_module: The torch module (required).
        logger: Optional logger for the fallback debug message.

    Returns:
        The number of entries in *sorted_entries* (0 for an empty list).

    Raises:
        RuntimeError: If numpy/torch are missing or a chunk hits EOF early.
        CancelledError: If *cancel_event* is set while streaming.
        OSError: For O_DIRECT failures other than EINVAL/EOPNOTSUPP.
    """
    if not sorted_entries:
        return 0
    if np_module is None or torch_module is None:
        raise RuntimeError("numpy and torch modules are required")

    total_size = sum(e.aligned_size for e in sorted_entries)

    # Allocate a buffer for the whole shard.  We intentionally avoid
    # torch.empty(pin_memory=True) because cudaHostAlloc is extremely
    # slow (~1-3 s per GiB) and dominates wall time for large shards.
    # A plain numpy buffer still gives good H2D throughput (the copy is
    # synchronous but PCIe bandwidth >> disk bandwidth).
    shard_arr = np_module.empty(total_size, dtype=np_module.uint8)

    # Number of entries already handed to work_q by the O_DIRECT path, so
    # the buffered fallback does not enqueue them a second time.
    queued = 0

    odirect_flag = getattr(os_module, "O_DIRECT", None)
    preadv_fn = getattr(os_module, "preadv", None)
    if odirect_flag is not None and preadv_fn is not None:
        fd: Optional[int] = None
        io_pool: Optional[ThreadPoolExecutor] = None
        try:
            fd = os_module.open(abs_path, os_module.O_RDONLY | odirect_flag)
            mv = memoryview(shard_arr)

            # Build the (offset, size) chunk list covering the full shard.
            chunk_size = _CHUNK_SIZE
            chunks: List[Tuple[int, int]] = [
                (off, min(chunk_size, total_size - off))
                for off in range(0, total_size, chunk_size)
            ]

            # chunks_done[i] is set when chunk i finishes (success or error).
            chunks_done = [threading.Event() for _ in chunks]
            chunk_errors: List[BaseException] = []

            def _read_chunk(idx: int) -> None:
                try:
                    c_off, c_sz = chunks[idx]
                    _preadv_chunk(fd, mv[c_off : c_off + c_sz], c_off, c_sz, os_module)
                except BaseException as exc:
                    chunk_errors.append(exc)
                finally:
                    chunks_done[idx].set()

            # Submit chunk reads with bounded concurrency.  max(1, ...) keeps
            # the pool valid when every entry has size 0 (no chunks at all).
            io_pool = ThreadPoolExecutor(
                max_workers=max(1, min(_IO_DEPTH, len(chunks)))
            )
            for i in range(len(chunks)):
                io_pool.submit(_read_chunk, i)

            # Stream entries to the work queue as their data arrives.
            for entry in sorted_entries:
                if cancel_event is not None and cancel_event.is_set():
                    raise CancelledError(f"shard read cancelled: {abs_path}")
                start_chunk = entry.tensor_offset // chunk_size
                end_chunk = (
                    entry.tensor_offset + entry.aligned_size - 1
                ) // chunk_size
                for ci in range(start_chunk, end_chunk + 1):
                    chunks_done[ci].wait()
                if chunk_errors:
                    raise chunk_errors[0]
                eoff = entry.tensor_offset
                tensor = torch_module.from_numpy(
                    shard_arr[eoff : eoff + entry.aligned_size]
                )
                _put_entry(work_q, entry, tensor, cancel_event, abs_path)
                queued += 1

            if chunk_errors:
                raise chunk_errors[0]
            return len(sorted_entries)
        except OSError as exc:
            fallback_errnos = {errno.EINVAL, errno.EOPNOTSUPP}
            if exc.errno not in fallback_errnos:
                raise
            if logger is not None:
                logger.debug(
                    "O_DIRECT preadv failed on %s (errno %s); "
                    "falling back to buffered read",
                    abs_path,
                    exc.errno,
                )
        finally:
            if io_pool is not None:
                # Drop chunks that have not started and wait for the ones in
                # flight BEFORE closing the descriptor.  Closing first would
                # let a reader thread preadv() on a closed -- or, worse,
                # reused -- fd number.
                io_pool.shutdown(wait=True, cancel_futures=True)
                io_pool = None
            if fd is not None:
                os_module.close(fd)
                fd = None

    # Fallback: buffered full-shard read, then queue the entries that the
    # O_DIRECT path had not already delivered.
    with open(abs_path, "rb") as handle:
        raw = handle.read()
    arr = np_module.frombuffer(raw, dtype=np_module.uint8).copy()
    for entry in sorted_entries[queued:]:
        off = entry.tensor_offset
        tensor = torch_module.from_numpy(arr[off : off + entry.aligned_size])
        _put_entry(work_q, entry, tensor, cancel_event, abs_path)
    return len(sorted_entries)
def read_shard_to_queue(
    abs_path: str,
    sorted_entries: List[AllocationEntry],
    work_q: queue.Queue[Optional[Tuple[AllocationEntry, torch.Tensor]]],
    *,
    pin_memory: bool,
    read_shard,
    cancel_event: Optional[threading.Event] = None,
) -> int:
    """Load a whole shard with *read_shard*, then enqueue each entry.

    *read_shard* is called as ``read_shard(abs_path, sorted_entries, -1,
    pin_memory=...)`` (device ``-1``) and must return a mapping keyed by
    ``allocation_id``.  Entries are queued in the order given, and the number
    of entries is returned.
    """
    loaded = read_shard(abs_path, sorted_entries, -1, pin_memory=pin_memory)
    count = 0
    for item in sorted_entries:
        payload = loaded[item.allocation_id]
        _put_entry(work_q, item, payload, cancel_event, abs_path)
        count += 1
    return len(sorted_entries)
def load_manifest_and_metadata(
    input_dir: str,
) -> Tuple[SaveManifest, Dict[str, Dict[str, Any]]]:
    """Load ``manifest.json`` and the optional ``gms_metadata.json``.

    The manifest is mandatory and parsed through ``SaveManifest.from_dict``.
    A missing metadata file is treated as empty metadata.
    """
    with open(
        os.path.join(input_dir, "manifest.json"), encoding="utf-8"
    ) as manifest_file:
        manifest = SaveManifest.from_dict(json.load(manifest_file))

    metadata_path = os.path.join(input_dir, "gms_metadata.json")
    if not os.path.exists(metadata_path):
        return manifest, decode_metadata({})

    with open(metadata_path, encoding="utf-8") as metadata_file:
        raw_meta: Dict[str, Any] = json.load(metadata_file)
    return manifest, decode_metadata(raw_meta)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List
CURRENT_VERSION = "1.0"
@dataclass(frozen=True)
class AllocationEntry:
    """Immutable record of one dumped allocation.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # Identifier of the allocation at dump time.
    allocation_id: str
    # Size of the allocation, as recorded in the manifest.
    size: int
    # Aligned size; this is the number of bytes stored for the entry.
    aligned_size: int
    # Free-form tag recorded alongside the allocation.
    tag: str
    # Path of the file that holds this allocation's bytes.
    tensor_file: str
    # Byte offset of the allocation inside tensor_file (0 when absent).
    tensor_offset: int = 0
@dataclass
class SaveManifest:
    """Manifest describing the contents of a GMS dump directory."""

    version: str
    timestamp: float
    layout_hash: str
    device: int
    allocations: List[AllocationEntry] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict; allocations become plain dicts."""
        payload: Dict[str, Any] = {
            "version": self.version,
            "timestamp": self.timestamp,
            "layout_hash": self.layout_hash,
            "device": self.device,
        }
        payload["allocations"] = list(map(asdict, self.allocations))
        return payload

    @classmethod
    def from_dict(cls, payload: Dict[str, Any]) -> "SaveManifest":
        """Build a manifest from its dict form.

        Raises ``ValueError`` when the payload's version differs from
        ``CURRENT_VERSION``.  A missing ``allocations`` key yields an empty
        list, and a missing ``tensor_offset`` on an entry defaults to 0.
        """
        version = payload["version"]
        if version != CURRENT_VERSION:
            raise ValueError(
                f"Unsupported manifest version {version!r} "
                f"(expected {CURRENT_VERSION!r})"
            )

        allocations: List[AllocationEntry] = []
        for record in payload.get("allocations", []):
            allocations.append(
                AllocationEntry(
                    allocation_id=record["allocation_id"],
                    size=record["size"],
                    aligned_size=record["aligned_size"],
                    tag=record["tag"],
                    tensor_file=record["tensor_file"],
                    tensor_offset=record.get("tensor_offset", 0),
                )
            )

        return cls(
            version=payload["version"],
            timestamp=payload["timestamp"],
            layout_hash=payload["layout_hash"],
            device=payload["device"],
            allocations=allocations,
        )
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import queue
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
if TYPE_CHECKING:
import torch
from gpu_memory_service.snapshot.model import AllocationEntry
WORK_QUEUE_DEPTH_MULTIPLIER = 4
@dataclass
class RestorePipelineContext:
    """Mutable state shared across disk, copy, and Phase A restore stages."""

    worker_count: int
    use_streams: bool
    device: int
    work_q: queue.Queue[Optional[Tuple[AllocationEntry, torch.Tensor]]]
    va_events: Dict[str, threading.Event]
    streams: List[torch.cuda.Stream]
    cancel_event: threading.Event = field(default_factory=threading.Event)
    vas: Dict[str, int] = field(default_factory=dict)
    staged_srcs: List[torch.Tensor] = field(default_factory=list)
    copy_errors: List[BaseException] = field(default_factory=list)
    lock: threading.Lock = field(default_factory=threading.Lock)

    @classmethod
    def build(
        cls,
        allocations: List[AllocationEntry],
        worker_count: int,
        *,
        device: int,
        use_streams: bool,
        torch_module,
    ) -> "RestorePipelineContext":
        """Create a context sized for *worker_count* copy workers.

        One CUDA stream per worker is created when *use_streams* is true
        (none otherwise), the work queue is bounded to ``worker_count *
        WORK_QUEUE_DEPTH_MULTIPLIER`` items, and every allocation gets its
        own unset event keyed by ``allocation_id``.
        """
        if use_streams:
            streams = [
                torch_module.cuda.Stream(device=device) for _ in range(worker_count)
            ]
        else:
            streams = []

        bounded_q = queue.Queue(maxsize=worker_count * WORK_QUEUE_DEPTH_MULTIPLIER)
        events = {}
        for item in allocations:
            events[item.allocation_id] = threading.Event()

        return cls(
            worker_count=worker_count,
            use_streams=use_streams,
            device=device,
            work_q=bounded_q,
            va_events=events,
            streams=streams,
        )
@dataclass
class RestorePipelineResources:
    """Live restore pipeline resources that must be torn down together."""

    # Shared pipeline state (queue, events, streams, error list).
    ctx: RestorePipelineContext
    # Executor running the shard-read tasks.
    disk_pool: ThreadPoolExecutor
    # Pending shard reads, each mapped to a string label for that read.
    disk_futures: Dict[Future[int], str]
    # Worker threads draining ctx.work_q.
    copy_threads: List[threading.Thread]
    # Starts out True; presumably cleared once teardown has run -- confirm
    # against the code that owns this object.
    active: bool = True
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS storage client: save GMS state to disk and load it back."""
from __future__ import annotations
import base64
import json
import logging
import os
import queue
import threading
import time
from collections import defaultdict
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from gpu_memory_service.snapshot.disk import ( # noqa: F401 re-exported for external callers
ShardWriter as _ShardWriter,
)
from gpu_memory_service.snapshot.disk import decode_metadata as _decode_metadata_impl
from gpu_memory_service.snapshot.disk import (
group_entries_by_shard as _group_entries_by_shard_impl,
)
from gpu_memory_service.snapshot.disk import (
load_manifest_and_metadata as _load_manifest_and_metadata_impl,
)
from gpu_memory_service.snapshot.disk import (
plan_shard_layout as _plan_shard_layout_impl,
)
from gpu_memory_service.snapshot.disk import (
read_shard_sequential as _read_shard_sequential_impl,
)
from gpu_memory_service.snapshot.disk import (
read_shard_to_queue as _read_shard_to_queue_impl,
)
from gpu_memory_service.snapshot.model import CURRENT_VERSION as _CURRENT_VERSION
from gpu_memory_service.snapshot.model import AllocationEntry, SaveManifest
from gpu_memory_service.snapshot.restore import (
RestorePipelineContext as _RestorePipelineContext,
)
from gpu_memory_service.snapshot.restore import (
RestorePipelineResources as _RestorePipelineResources,
)
logger = logging.getLogger(__name__)
try:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from gpu_memory_service.client.torch.tensor import _tensor_from_pointer
from gpu_memory_service.common.locks import RequestedLockType
_GMS_IMPORTS_AVAILABLE = True
except ImportError:
_GMS_IMPORTS_AVAILABLE = False
GMSClientMemoryManager = None # type: ignore[assignment,misc]
_tensor_from_pointer = None # type: ignore[assignment]
RequestedLockType = None # type: ignore[assignment]
try:
import torch
_TORCH_AVAILABLE = True
except ImportError:
_TORCH_AVAILABLE = False
torch = None # type: ignore[assignment]
def _read_shard_sequential(
    abs_path: str,
    sorted_entries: List[AllocationEntry],
    device: int,
    pin_memory: bool = False,
) -> Dict[str, "torch.Tensor"]:
    """Facade kept for test patchability and backwards compatibility.

    Delegates to the snapshot.disk implementation, injecting this module's
    ``os``, ``np``, ``torch`` and ``logger``.
    """
    injected = {
        "os_module": os,
        "np_module": np,
        "torch_module": torch,
        "logger": logger,
    }
    return _read_shard_sequential_impl(
        abs_path,
        sorted_entries,
        device,
        pin_memory=pin_memory,
        **injected,
    )
def _decode_metadata(raw_meta: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Decode raw metadata records via the snapshot.disk implementation.

    Re-exported for external callers (e.g. multi_ssd_bench.py).
    """
    decoded = _decode_metadata_impl(raw_meta)
    return decoded
def _group_entries_by_shard(
    allocations: List[AllocationEntry],
) -> Dict[str, List[AllocationEntry]]:
    """Module-level alias for the snapshot.disk shard-grouping helper."""
    grouped = _group_entries_by_shard_impl(allocations)
    return grouped
def _allocation_record(alloc: Any) -> Dict[str, Any]:
if isinstance(alloc, dict):
return alloc
return {
"allocation_id": str(alloc.allocation_id),
"size": int(alloc.size),
"aligned_size": int(alloc.aligned_size),
"tag": str(alloc.tag),
"layout_slot": int(alloc.layout_slot),
}
def _plan_shard_layout(
    allocations_info: List[Dict[str, Any]],
    shard_size_bytes: int,
) -> List[Tuple[int, int]]:
    """Module-level alias for the snapshot.disk shard layout planner."""
    layout = _plan_shard_layout_impl(allocations_info, shard_size_bytes)
    return layout
def _read_shard_to_queue(
    abs_path: str,
    sorted_entries: List[AllocationEntry],
    work_q: "queue.Queue[Optional[Tuple[AllocationEntry, 'torch.Tensor']]]",
    *,
    pin_memory: bool,
    cancel_event: Optional[threading.Event] = None,
) -> int:
    """Read one shard and enqueue its entries on *work_q*.

    Delegates to the snapshot.disk implementation, using this module's
    ``_read_shard_sequential`` facade as the shard reader.
    """
    options = {
        "pin_memory": pin_memory,
        "read_shard": _read_shard_sequential,
        "cancel_event": cancel_event,
    }
    return _read_shard_to_queue_impl(abs_path, sorted_entries, work_q, **options)
def _load_manifest_and_metadata(
    input_dir: str,
) -> Tuple[SaveManifest, Dict[str, Dict[str, Any]]]:
    """Module-level alias for the snapshot.disk manifest/metadata loader."""
    loaded = _load_manifest_and_metadata_impl(input_dir)
    return loaded
class GMSStorageClient:
    """Dump and restore GMS state to/from disk."""

    def __init__(
        self,
        output_dir: Optional[str] = None,
        socket_path: Optional[str] = None,
        device: int = 0,
        *,
        timeout_ms: Optional[int] = None,
        shard_size_bytes: int = 4 * 1024**3,
    ) -> None:
        """Store the client configuration; no connection is made here.

        Args:
            output_dir: Dump directory; required only for ``save()``.
            socket_path: GMS socket path.  When omitted it is resolved via
                ``get_socket_path(device)``.
            device: Device index used for mappings and copies.
            timeout_ms: Timeout forwarded to ``connect``.
            shard_size_bytes: Target size of each shard file.
        """
        self.output_dir = output_dir
        self.device = device
        self._timeout_ms = timeout_ms
        self._shard_size = shard_size_bytes
        if socket_path is None:
            from gpu_memory_service.common.utils import get_socket_path

            socket_path = get_socket_path(device)
        self._socket_path = socket_path

    def save(self, max_workers: int = 4) -> SaveManifest:
        """Connect to GMS in RO mode and save all allocations + metadata to disk.

        Writes the shard files, ``gms_metadata.json`` and ``manifest.json``
        under ``output_dir`` and returns the manifest.

        Raises:
            RuntimeError: If the GMS imports are unavailable, the server has
                no layout hash, or a shard writer leaves an entry unset.
            ValueError: If ``output_dir`` was not provided.
        """
        self._validate_save_request()
        output_dir, shards_dir = self._prepare_output_dir()
        mm = GMSClientMemoryManager(self._socket_path, device=self.device)
        try:
            mm.connect(RequestedLockType.RO, timeout_ms=self._timeout_ms)
            layout_hash = mm.get_memory_layout_hash()
            if not layout_hash:
                raise RuntimeError(
                    "GMS server has no committed weights; nothing to dump"
                )
            allocations_info = [
                _allocation_record(alloc) for alloc in mm.list_handles()
            ]
            va_list = self._import_source_mappings(mm, allocations_info)
            entries = self._write_shards(
                shards_dir,
                allocations_info,
                va_list,
                max_workers=max_workers,
            )
            metadata = self._save_metadata(mm)

            self._write_json(os.path.join(output_dir, "gms_metadata.json"), metadata)
            manifest = SaveManifest(
                version=_CURRENT_VERSION,
                timestamp=time.time(),
                layout_hash=layout_hash,
                device=self.device,
                allocations=entries,
            )
            self._write_json(
                os.path.join(output_dir, "manifest.json"), manifest.to_dict()
            )
            logger.info("Wrote manifest with %d allocations", len(entries))
            return manifest
        finally:
            # Always release the connection, including when writing the JSON
            # files fails.  Best-effort cleanup; CUDA context may be invalid
            # after checkpoint (cuda-checkpoint tears down device state).
            mm.close(best_effort=True)

    def _validate_save_request(self) -> None:
        """Raise unless the GMS imports are available and ``output_dir`` is set."""
        if not _GMS_IMPORTS_AVAILABLE:
            raise RuntimeError(
                "GMS client imports unavailable (missing cuda-python or torch)"
            )
        if self.output_dir is None:
            raise ValueError(
                "output_dir must be set to call save(); pass it to GMSStorageClient()"
            )

    def _prepare_output_dir(self) -> Tuple[str, str]:
        """Create the output and shards directories and drop stale shards.

        Existing ``shard_*.bin`` files are deleted so an earlier dump cannot
        leave extra shards behind.  Returns ``(output_dir, shards_dir)``.
        """
        assert self.output_dir is not None
        os.makedirs(self.output_dir, exist_ok=True)
        shards_dir = os.path.join(self.output_dir, "shards")
        os.makedirs(shards_dir, exist_ok=True)
        for name in os.listdir(shards_dir):
            if name.startswith("shard_") and name.endswith(".bin"):
                os.unlink(os.path.join(shards_dir, name))
        return self.output_dir, shards_dir

    def _import_source_mappings(
        self,
        mm: Any,
        allocations_info: List[Dict[str, Any]],
    ) -> List[int]:
        """Create one mapping per allocation; results align with *allocations_info*."""
        va_list = [
            mm.create_mapping(allocation_id=alloc["allocation_id"])
            for alloc in allocations_info
        ]
        logger.info("Phase A complete: imported %d allocation VAs", len(va_list))
        return va_list

    def _write_shards(
        self,
        shards_dir: str,
        allocations_info: List[Dict[str, Any]],
        va_list: List[int],
        *,
        max_workers: int,
    ) -> List[AllocationEntry]:
        """Write every allocation into its planned shard, one task per shard.

        Returns the manifest entries in the same order as *allocations_info*.

        Raises:
            RuntimeError: If any allocation has no entry once all shard
                writers have finished.
        """
        layout = _plan_shard_layout(allocations_info, self._shard_size)
        shard_groups: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
        for index, (shard_idx, byte_offset) in enumerate(layout):
            shard_groups[shard_idx].append((index, byte_offset))

        # Pre-sized so each writer fills only the slots of its own shard.
        entries: List[Optional[AllocationEntry]] = [None] * len(allocations_info)

        def _write_one_shard(
            shard_idx: int, alloc_pairs: List[Tuple[int, int]]
        ) -> None:
            filename = f"shard_{shard_idx:04d}.bin"
            abs_path = os.path.join(shards_dir, filename)
            rel_path = os.path.join("shards", filename)
            with open(abs_path, "wb") as handle:
                for index, byte_offset in alloc_pairs:
                    alloc = allocations_info[index]
                    aligned_size = int(alloc["aligned_size"])
                    tensor = _tensor_from_pointer(
                        va_list[index],
                        [aligned_size],
                        [1],
                        torch.uint8,
                        self.device,
                    )
                    tensor.cpu().numpy().tofile(handle)
                    entries[index] = AllocationEntry(
                        allocation_id=alloc["allocation_id"],
                        size=int(alloc["size"]),
                        aligned_size=aligned_size,
                        tag=str(alloc.get("tag", "default")),
                        tensor_file=rel_path,
                        tensor_offset=byte_offset,
                    )

        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = {
                pool.submit(_write_one_shard, shard_idx, alloc_pairs): shard_idx
                for shard_idx, alloc_pairs in shard_groups.items()
            }
            for future in as_completed(futures):
                future.result()

        missing = sum(1 for entry in entries if entry is None)
        if missing:
            raise RuntimeError(
                f"BUG: {missing} allocation(s) missing after shard writers completed"
            )
        logger.info("Phase B complete: wrote %d shards", len(shard_groups))
        return [entry for entry in entries if entry is not None]

    def _write_json(self, path: str, payload: Dict[str, Any]) -> None:
        """Write *payload* to *path* as indented UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2)

    def _run_restore_copy_worker(
        self,
        ctx: _RestorePipelineContext,
        stream_idx: int,
    ) -> None:
        """Copy queued source tensors into their destination mappings.

        Runs until a ``None`` sentinel is dequeued, or until the cancel event
        is observed while idle or while waiting for an entry's event.  Copy
        failures are recorded in ``ctx.copy_errors`` and the loop continues.
        """
        while True:
            try:
                item = ctx.work_q.get(timeout=0.1)
            except queue.Empty:
                if ctx.cancel_event.is_set():
                    return
                continue
            if item is None:
                return
            entry, src = item
            try:
                # Wait in short slices so cancellation is noticed promptly.
                while not ctx.va_events[entry.allocation_id].wait(timeout=0.1):
                    if ctx.cancel_event.is_set():
                        return
                dst = _tensor_from_pointer(
                    ctx.vas[entry.allocation_id],
                    [entry.aligned_size],
                    [1],
                    torch.uint8,
                    self.device,
                )
                if ctx.streams:
                    with torch.cuda.stream(ctx.streams[stream_idx]):
                        dst.copy_(src, non_blocking=src.is_pinned())
                else:
                    dst.copy_(src)
                if ctx.use_streams and src.is_pinned():
                    # Keep the pinned source referenced until the pipeline is
                    # finalized (staged_srcs is cleared there).
                    with ctx.lock:
                        ctx.staged_srcs.append(src)
            except Exception as exc:  # noqa: BLE001
                with ctx.lock:
                    ctx.copy_errors.append(exc)

    def _start_restore_copy_threads(
        self,
        ctx: _RestorePipelineContext,
    ) -> List[threading.Thread]:
        """Start ``ctx.worker_count`` daemon copy workers and return them."""
        threads = [
            threading.Thread(
                target=self._run_restore_copy_worker,
                args=(ctx, index),
                daemon=True,
            )
            for index in range(ctx.worker_count)
        ]
        for thread in threads:
            thread.start()
        return threads

    def _prepare_restore_pipeline(
        self,
        manifest: SaveManifest,
        groups: Dict[str, List[AllocationEntry]],
        worker_count: int,
        input_dir: str,
    ) -> _RestorePipelineResources:
        """Build the restore context, start copy workers and submit shard reads."""
        ctx = _RestorePipelineContext.build(
            manifest.allocations,
            worker_count,
            device=self.device,
            use_streams=_TORCH_AVAILABLE and torch.cuda.is_available(),
            torch_module=torch,
        )
        copy_threads = self._start_restore_copy_threads(ctx)
        disk_pool = ThreadPoolExecutor(max_workers=worker_count)
        disk_futures = {
            disk_pool.submit(
                _read_shard_to_queue,
                os.path.join(input_dir, rel_path),
                sorted_entries,
                ctx.work_q,
                pin_memory=ctx.use_streams,
                cancel_event=ctx.cancel_event,
            ): rel_path
            for rel_path, sorted_entries in groups.items()
        }
        return _RestorePipelineResources(
            ctx=ctx,
            disk_pool=disk_pool,
            disk_futures=disk_futures,
            copy_threads=copy_threads,
        )

    def _allocate_restore_mappings(
        self,
        mm: Any,
        manifest: SaveManifest,
        ctx: _RestorePipelineContext,
    ) -> Dict[str, str]:
        """Create a mapping per manifest entry and release its copy worker.

        Returns a map from each saved allocation id to the id of the newly
        created mapping.
        """
        id_map: Dict[str, str] = {}
        for entry in manifest.allocations:
            old_id = entry.allocation_id
            va = mm.create_mapping(size=entry.size, tag=entry.tag)
            id_map[old_id] = mm.mappings[va].allocation_id
            ctx.vas[old_id] = va
            # Publish the VA before signalling so a waiting worker finds it.
            ctx.va_events[old_id].set()
        logger.info(
            "Phase A complete: allocated %d GMS VAs; waiting for disk/copy pipeline",
            len(ctx.vas),
        )
        return id_map

    def _await_disk_reads(self, disk_futures: Dict[Future[int], str]) -> None:
        """Wait for all shard reads; wrap the first failure in ``RuntimeError``.

        ``CancelledError`` from a read is ignored.
        """
        for future in as_completed(disk_futures):
            rel_path = disk_futures[future]
            try:
                future.result()
            except CancelledError:
                pass
            except Exception as exc:
                raise RuntimeError(f"Failed to load shard {rel_path}: {exc}") from exc

    def _stop_restore_copy_threads(
        self,
        ctx: _RestorePipelineContext,
        threads: List[threading.Thread],
        *,
        drain_queue: bool = False,
    ) -> None:
        """Send one ``None`` sentinel per worker and join all workers."""
        if drain_queue:
            self._drain_restore_queue(ctx)
        for _ in threads:
            if drain_queue:
                # Cancel path: workers may have exited, so drain to make room.
                while True:
                    try:
                        ctx.work_q.put(None, timeout=0.1)
                        break
                    except queue.Full:
                        self._drain_restore_queue(ctx)
            else:
                # Normal path: disk reads are done and workers are alive; block
                # until a slot opens rather than spinning with a timeout.
                ctx.work_q.put(None)
        for thread in threads:
            thread.join()

    def _drain_restore_queue(self, ctx: _RestorePipelineContext) -> None:
        """Discard everything currently in ``ctx.work_q`` without blocking."""
        while True:
            try:
                ctx.work_q.get_nowait()
            except queue.Empty:
                return

    def _cancel_restore_pipeline(self, ctx: _RestorePipelineContext) -> None:
        """Set the cancel event, wake all event waiters and empty the queue."""
        ctx.cancel_event.set()
        for event in ctx.va_events.values():
            event.set()
        self._drain_restore_queue(ctx)

    def _finalize_restore_pipeline(self, ctx: _RestorePipelineContext) -> None:
        """Synchronize the device (when streams are in use) and report copy errors.

        Raises:
            RuntimeError: If any copy worker recorded an error.
        """
        if ctx.use_streams:
            torch.cuda.synchronize(device=self.device)
        ctx.staged_srcs.clear()
        if ctx.copy_errors:
            raise RuntimeError(
                f"Failed to copy restored data to GMS: {ctx.copy_errors[0]}"
            )

    def _drain_restore_pipeline(self, resources: _RestorePipelineResources) -> None:
        """Run the pipeline to completion and tear it down.

        A disk-read error cancels the pipeline but teardown still runs; the
        disk error takes precedence over a finalize error when re-raising.
        """
        disk_error: Optional[BaseException] = None
        finalize_error: Optional[BaseException] = None
        drain_queue = False
        try:
            self._await_disk_reads(resources.disk_futures)
        except Exception as exc:
            disk_error = exc
            self._cancel_restore_pipeline(resources.ctx)
            drain_queue = True
            resources.disk_pool.shutdown(wait=True, cancel_futures=True)
        else:
            resources.disk_pool.shutdown(wait=True)
        try:
            self._stop_restore_copy_threads(
                resources.ctx,
                resources.copy_threads,
                drain_queue=drain_queue,
            )
        finally:
            resources.active = False
        try:
            self._finalize_restore_pipeline(resources.ctx)
        except Exception as exc:  # noqa: BLE001
            finalize_error = exc
        if disk_error is not None:
            raise disk_error
        if finalize_error is not None:
            raise finalize_error

    def _shutdown_restore_pipeline(
        self,
        resources: _RestorePipelineResources,
    ) -> None:
        """Cancel and tear down a still-active pipeline during error handling.

        Does nothing when the pipeline was already torn down.  Never raises
        for copy errors: they are logged instead.
        """
        if not resources.active:
            return
        self._cancel_restore_pipeline(resources.ctx)
        resources.disk_pool.shutdown(wait=True, cancel_futures=True)
        self._stop_restore_copy_threads(
            resources.ctx,
            resources.copy_threads,
            drain_queue=True,
        )
        resources.active = False
        # Synchronize async copies to prevent use-after-free of staged pinned
        # buffers, but suppress copy errors — the caller already has an error
        # to propagate and we must not mask it.
        try:
            self._finalize_restore_pipeline(resources.ctx)
        except Exception:  # noqa: BLE001
            # Use the module-level logger: this class defines no `_logger`
            # attribute, so `self._logger` would raise AttributeError here
            # and hide the very error this path is meant to preserve.
            logger.warning(
                "cleanup failed during restore error handling",
                exc_info=True,
            )

    def load_to_gms(
        self,
        input_dir: str,
        *,
        max_workers: int = 4,
        clear_existing: bool = True,
    ) -> Dict[str, str]:
        """Restore a dump from *input_dir* into GMS and commit it.

        Args:
            input_dir: Directory containing ``manifest.json`` and the shards.
            max_workers: Upper bound on disk/copy workers; the effective
                count is also capped by the number of shard files.
            clear_existing: Only controls whether an informational message
                about the RW connect is logged.

        Returns:
            Map from each saved allocation id to its new allocation id.

        Raises:
            RuntimeError: If the GMS imports are unavailable, a shard or copy
                fails, a metadata write fails, or the commit fails.
        """
        if not _GMS_IMPORTS_AVAILABLE:
            raise RuntimeError(
                "GMS client imports unavailable (missing cuda-python or torch)"
            )
        manifest, saved_metadata = _load_manifest_and_metadata(input_dir)
        groups = _group_entries_by_shard(manifest.allocations)
        worker_count = max(1, min(max_workers, len(groups) or 1))
        with GMSClientMemoryManager(self._socket_path, device=self.device) as mm:
            mm.connect(RequestedLockType.RW, timeout_ms=self._timeout_ms)
            if clear_existing:
                logger.info("RW connect cleared any previously committed GMS state")
            resources = self._prepare_restore_pipeline(
                manifest,
                groups,
                worker_count,
                input_dir,
            )
            try:
                id_map = self._allocate_restore_mappings(mm, manifest, resources.ctx)
                self._drain_restore_pipeline(resources)
            except Exception:
                self._shutdown_restore_pipeline(resources)
                raise
            logger.info(
                "Phase B complete: streamed %d allocations to GMS memory",
                len(manifest.allocations),
            )
            self._restore_metadata(mm, saved_metadata, id_map)
            if not mm.commit():
                raise RuntimeError("GMS commit failed after restore")
        logger.info(
            "load_to_gms complete: %d allocations, %d metadata keys",
            len(id_map),
            len(saved_metadata),
        )
        return id_map

    def _restore_metadata(
        self,
        mm: Any,
        saved_metadata: Dict[str, Dict[str, Any]],
        id_map: Dict[str, str],
    ) -> None:
        """Re-write saved metadata, translating allocation ids through *id_map*.

        Ids absent from *id_map* are written unchanged.

        Raises:
            RuntimeError: If ``metadata_put`` reports failure for a key.
        """
        for key, meta in saved_metadata.items():
            old_alloc_id = meta["allocation_id"]
            new_alloc_id = id_map.get(old_alloc_id, old_alloc_id)
            ok = mm.metadata_put(key, new_alloc_id, meta["offset_bytes"], meta["value"])
            if not ok:
                raise RuntimeError(f"Failed to write metadata key={key!r}")
            logger.debug("Restored metadata key=%s -> alloc=%s", key, new_alloc_id)
        logger.info("Restored %d metadata keys; committing", len(saved_metadata))

    @staticmethod
    def load_tensors(
        input_dir: str,
        device: int = 0,
        *,
        max_workers: int = 4,
    ) -> Tuple[Dict[str, "torch.Tensor"], Dict[str, Dict[str, Any]]]:
        """Load a dump into tensors without talking to GMS.

        Returns ``(tensors keyed by allocation id, decoded metadata)``.

        Raises:
            RuntimeError: If PyTorch is unavailable or a shard fails to load.
        """
        if not _TORCH_AVAILABLE:
            raise RuntimeError("PyTorch is required for load_tensors()")
        manifest, metadata = _load_manifest_and_metadata(input_dir)
        groups = _group_entries_by_shard(manifest.allocations)
        tensors: Dict[str, "torch.Tensor"] = {}
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = {
                pool.submit(
                    _read_shard_sequential,
                    os.path.join(input_dir, rel_path),
                    sorted_entries,
                    device,
                ): rel_path
                for rel_path, sorted_entries in groups.items()
            }
            for future in as_completed(futures):
                rel_path = futures[future]
                try:
                    tensors.update(future.result())
                except Exception as exc:
                    raise RuntimeError(
                        f"Failed to load shard {rel_path}: {exc}"
                    ) from exc
        logger.info("Loaded %d allocations from %s", len(tensors), input_dir)
        return tensors, metadata

    def _save_metadata(self, mm: Any) -> Dict[str, Any]:
        """Collect all metadata entries from *mm* in JSON-serialisable form.

        Values are base64-encoded.  A key whose lookup returns ``None`` is
        skipped with a warning.
        """
        result: Dict[str, Any] = {}
        for key in mm.metadata_list():
            got = mm.metadata_get(key)
            if got is None:
                logger.warning("Metadata key disappeared during dump: %s", key)
                continue
            allocation_id, offset_bytes, value = got
            result[key] = {
                "allocation_id": str(allocation_id),
                "offset_bytes": int(offset_bytes),
                "value": base64.b64encode(value).decode("ascii"),
            }
        return result
......@@ -245,9 +245,6 @@ def running_gms(monkeypatch, tmp_path):
server_allocations, "cumem_export_to_shareable_handle", export_fd
)
monkeypatch.setattr(
client_memory_manager, "cuda_set_current_device", lambda device: None
)
monkeypatch.setattr(
client_memory_manager,
"cumem_get_allocation_granularity",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment