Unverified Commit f3b181a9 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(gms): operator-managed GMS checkpoint/restore support (#8153)

parent 091cdb51
...@@ -32,16 +32,15 @@ def patch_empty_cache() -> None: ...@@ -32,16 +32,15 @@ def patch_empty_cache() -> None:
_original_empty_cache = torch.cuda.empty_cache _original_empty_cache = torch.cuda.empty_cache
def safe_empty_cache() -> None: def safe_empty_cache() -> None:
active_mapping_count = sum( # Allow empty_cache when all managers are unmapped (sleep/checkpoint)
1 # or when there are no active VMM mappings with live handles.
has_live_mappings = any(
any(m.handle != 0 for m in manager.mappings.values())
for manager in get_gms_client_memory_managers() for manager in get_gms_client_memory_managers()
for mapping in manager.mappings.values()
if mapping.handle != 0
) )
if active_mapping_count: if has_live_mappings:
logger.warning( logger.debug(
"[GMS] Skipping torch.cuda.empty_cache() - %d active GMS mappings", "[GMS] Skipping torch.cuda.empty_cache() - live VMM mappings active",
active_mapping_count,
) )
return return
_original_empty_cache() _original_empty_cache()
......
...@@ -36,6 +36,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "gpu", "memory", "dynamo"] ...@@ -36,6 +36,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "gpu", "memory", "dynamo"]
[project.scripts] [project.scripts]
gpu-memory-service = "gpu_memory_service.cli.runner:main" gpu-memory-service = "gpu_memory_service.cli.runner:main"
gms-storage-client = "gpu_memory_service.cli.storage_runner:main"
[project.optional-dependencies] [project.optional-dependencies]
test = [ test = [
......
...@@ -98,6 +98,12 @@ setup( ...@@ -98,6 +98,12 @@ setup(
package_data={ package_data={
"gpu_memory_service.client.torch.extensions": ["*.cpp"], "gpu_memory_service.client.torch.extensions": ["*.cpp"],
}, },
entry_points={
"console_scripts": [
"gpu-memory-service=gpu_memory_service.cli.runner:main",
"gms-storage-client=gpu_memory_service.cli.storage_runner:main",
]
},
ext_modules=_create_ext_modules(), ext_modules=_create_ext_modules(),
cmdclass={"build_ext": BuildExtension}, cmdclass={"build_ext": BuildExtension},
) )
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import base64
import errno
import json
import os
import queue
import threading
from collections import defaultdict
from concurrent.futures import CancelledError, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
if TYPE_CHECKING:
import torch
from gpu_memory_service.snapshot.model import AllocationEntry, SaveManifest
class ShardWriter:
"""Packs allocation bytes sequentially into large binary shard files.
This is a single-threaded utility for streaming writes. The parallel save
path in GMSStorageClient._write_shards assigns allocations to shards via
plan_shard_layout and writes each shard file concurrently, so it does not
use ShardWriter directly. ShardWriter is kept as a public utility for
callers that want a simple sequential writer.
"""
def __init__(self, shards_dir: str, shard_size_bytes: int = 4 * 1024**3) -> None:
self._shards_dir = shards_dir
self._shard_size = shard_size_bytes
self._shard_idx = -1
self._current_offset = 0
self._current_file: Optional[Any] = None
self._current_rel_path: str = ""
os.makedirs(shards_dir, exist_ok=True)
def _roll_shard(self) -> None:
if self._current_file is not None:
self._current_file.close()
self._shard_idx += 1
filename = f"shard_{self._shard_idx:04d}.bin"
abs_path = os.path.join(self._shards_dir, filename)
self._current_file = open(abs_path, "wb")
self._current_rel_path = os.path.join("shards", filename)
self._current_offset = 0
def write(self, tensor: torch.Tensor) -> Tuple[str, int]:
cpu = tensor.cpu() if hasattr(tensor, "is_cuda") and tensor.is_cuda else tensor
if hasattr(cpu, "is_contiguous") and not cpu.is_contiguous():
cpu = cpu.contiguous()
arr = cpu.numpy()
size = arr.nbytes
if self._current_file is None or (
self._current_offset > 0 and self._current_offset + size > self._shard_size
):
self._roll_shard()
offset = self._current_offset
arr.tofile(self._current_file)
self._current_offset += size
return self._current_rel_path, offset
def close(self) -> None:
if self._current_file is not None:
self._current_file.close()
self._current_file = None
def __enter__(self) -> "ShardWriter":
return self
def __exit__(self, *_: Any) -> None:
self.close()
def read_shard_sequential(
abs_path: str,
sorted_entries: List[AllocationEntry],
device: int,
*,
pin_memory: bool = False,
os_module=os,
np_module=None,
torch_module=None,
logger=None,
) -> Dict[str, torch.Tensor]:
"""Read one shard file front-to-back without seeking."""
if np_module is None or torch_module is None:
raise RuntimeError("numpy and torch modules are required to read shards")
result: Dict[str, torch.Tensor] = {}
device_str = f"cuda:{device}" if device >= 0 else "cpu"
if abs_path.endswith(".pt"):
if len(sorted_entries) != 1:
raise RuntimeError(
f"Expected exactly 1 entry for legacy .pt file, got "
f"{len(sorted_entries)}: {abs_path}"
)
entry = sorted_entries[0]
result[entry.allocation_id] = torch_module.load(
abs_path,
weights_only=True,
map_location=device_str,
)
return result
odirect_flag = getattr(os_module, "O_DIRECT", None)
if odirect_flag is not None:
fd: Optional[int] = None
done = 0
try:
total_size = sum(entry.aligned_size for entry in sorted_entries)
# Avoid torch.empty(pin_memory=True): cudaHostAlloc is ~1-3 s/GiB
# and dominates wall time. Plain numpy gives good throughput since
# PCIe H2D bandwidth far exceeds network disk bandwidth.
shard_t = None
arr = np_module.empty(total_size, dtype=np_module.uint8)
fd = os_module.open(abs_path, os_module.O_RDONLY | odirect_flag)
try:
mv = memoryview(arr)
try:
while done < total_size:
read = os_module.readv(fd, [mv[done:]])
if read == 0:
raise RuntimeError(
f"Unexpected EOF in O_DIRECT read from {abs_path}: "
f"got {done} of {total_size} bytes"
)
done += read
finally:
mv.release()
finally:
os_module.close(fd)
offset = 0
for entry in sorted_entries:
size = entry.aligned_size
if shard_t is not None:
tensor = shard_t[offset : offset + size]
else:
tensor = torch_module.from_numpy(arr[offset : offset + size])
if device >= 0:
tensor = tensor.to(device_str)
result[entry.allocation_id] = tensor
offset += size
return result
except OSError as exc:
fallback_errnos = {errno.EINVAL, errno.EOPNOTSUPP}
if fd is not None and exc.errno not in fallback_errnos:
raise
result.clear()
if logger is not None:
if fd is None:
logger.debug(
"O_DIRECT unsupported on %s (errno %s); using buffered reads",
abs_path,
exc.errno,
)
else:
logger.debug(
"O_DIRECT read on %s hit EINVAL after %d/%d bytes; using buffered reads",
abs_path,
done,
total_size,
)
if sorted_entries and sorted_entries[0].tensor_offset != 0:
raise RuntimeError(
f"Buffered shard read requires entries starting at offset 0, "
f"got {sorted_entries[0].tensor_offset} in {abs_path}"
)
with open(abs_path, "rb") as handle:
for entry in sorted_entries:
raw = handle.read(entry.aligned_size)
if len(raw) != entry.aligned_size:
raise RuntimeError(
f"Short read from {abs_path} at offset {entry.tensor_offset}: "
f"expected {entry.aligned_size} bytes, got {len(raw)}"
)
arr = np_module.frombuffer(raw, dtype=np_module.uint8).copy()
tensor = torch_module.from_numpy(arr)
if device >= 0:
tensor = tensor.to(device_str)
result[entry.allocation_id] = tensor
return result
def decode_metadata(raw_meta: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
return {
key: {
"allocation_id": entry["allocation_id"],
"offset_bytes": int(entry["offset_bytes"]),
"value": base64.b64decode(entry["value"]),
}
for key, entry in raw_meta.items()
}
def group_entries_by_shard(
allocations: List[AllocationEntry],
) -> Dict[str, List[AllocationEntry]]:
groups: Dict[str, List[AllocationEntry]] = defaultdict(list)
for entry in allocations:
groups[entry.tensor_file].append(entry)
for entries in groups.values():
entries.sort(key=lambda entry: entry.tensor_offset)
return dict(groups)
def plan_shard_layout(
allocations_info: List[Dict[str, Any]],
shard_size_bytes: int,
) -> List[Tuple[int, int]]:
result: List[Tuple[int, int]] = []
shard_idx = -1
current_offset = 0
started = False
for alloc in allocations_info:
size = int(alloc["aligned_size"])
if not started or (
current_offset > 0 and current_offset + size > shard_size_bytes
):
shard_idx += 1
current_offset = 0
started = True
result.append((shard_idx, current_offset))
current_offset += size
return result
def _put_entry(
work_q: queue.Queue[Optional[Tuple[AllocationEntry, "torch.Tensor"]]],
entry: AllocationEntry,
tensor: "torch.Tensor",
cancel_event: Optional[threading.Event],
abs_path: str,
) -> None:
"""Put one entry into the work queue, respecting cancellation."""
while True:
if cancel_event is not None and cancel_event.is_set():
raise CancelledError(f"shard read cancelled: {abs_path}")
try:
work_q.put((entry, tensor), timeout=0.1)
return
except queue.Full:
pass
# 64 MiB chunks for parallel preadv — gives high effective iodepth on NFS
# while keeping each syscall large enough to amortize overhead.
_CHUNK_SIZE = 64 * 1024 * 1024
# How many preadv calls to keep in-flight per shard. On Vast NFS each
# outstanding preadv becomes a separate NFS READ RPC, so higher iodepth
# means more network-level parallelism from a single file descriptor.
_IO_DEPTH = 16
def _preadv_chunk(
fd: int,
buf: memoryview,
file_offset: int,
size: int,
os_module,
) -> None:
"""Read exactly *size* bytes from *fd* at *file_offset* into *buf*."""
done = 0
while done < size:
n = os_module.preadv(fd, [buf[done:size]], file_offset + done)
if n == 0:
raise RuntimeError(
f"Unexpected EOF in preadv at offset {file_offset + done}"
)
done += n
def read_shard_streaming_to_queue(
abs_path: str,
sorted_entries: List[AllocationEntry],
work_q: queue.Queue[Optional[Tuple[AllocationEntry, "torch.Tensor"]]],
*,
pin_memory: bool,
cancel_event: Optional[threading.Event] = None,
os_module=os,
np_module=None,
torch_module=None,
logger=None,
) -> int:
"""Read a shard via parallel O_DIRECT preadv calls, streaming entries
to *work_q* as they become readable.
Multiple chunks are read concurrently from different file offsets to
achieve high effective I/O depth on network filesystems (e.g. Vast NFS)
where single-threaded synchronous reads severely under-utilize bandwidth.
"""
if not sorted_entries:
return 0
if np_module is None or torch_module is None:
raise RuntimeError("numpy and torch modules are required")
total_size = sum(e.aligned_size for e in sorted_entries)
# Allocate a buffer for the whole shard. We intentionally avoid
# torch.empty(pin_memory=True) because cudaHostAlloc is extremely
# slow (~1-3 s per GiB) and dominates wall time for large shards.
# A plain numpy buffer still gives good H2D throughput (the copy is
# synchronous but PCIe bandwidth ≫ disk bandwidth).
shard_t = None
shard_arr = np_module.empty(total_size, dtype=np_module.uint8)
odirect_flag = getattr(os_module, "O_DIRECT", None)
preadv_fn = getattr(os_module, "preadv", None)
if odirect_flag is not None and preadv_fn is not None:
fd: Optional[int] = None
io_pool: Optional[ThreadPoolExecutor] = None
try:
fd = os_module.open(abs_path, os_module.O_RDONLY | odirect_flag)
mv = memoryview(shard_arr)
# Build aligned chunk list covering the full shard.
chunk_size = _CHUNK_SIZE
chunks: List[Tuple[int, int]] = [] # (offset, size)
off = 0
while off < total_size:
sz = min(chunk_size, total_size - off)
chunks.append((off, sz))
off += sz
# chunks_done[i] is set when chunk i finishes (success or error).
chunks_done = [threading.Event() for _ in chunks]
chunk_errors: List[BaseException] = []
def _read_chunk(idx: int) -> None:
try:
c_off, c_sz = chunks[idx]
_preadv_chunk(fd, mv[c_off : c_off + c_sz], c_off, c_sz, os_module)
except BaseException as exc:
chunk_errors.append(exc)
finally:
chunks_done[idx].set()
# Submit chunk reads with bounded concurrency.
io_pool = ThreadPoolExecutor(max_workers=min(_IO_DEPTH, len(chunks)))
for i in range(len(chunks)):
io_pool.submit(_read_chunk, i)
# Stream entries to the work queue as their data arrives.
def _chunk_for_byte(byte_off: int) -> int:
return byte_off // chunk_size
for entry_idx in range(len(sorted_entries)):
if cancel_event is not None and cancel_event.is_set():
raise CancelledError(f"shard read cancelled: {abs_path}")
entry = sorted_entries[entry_idx]
start_chunk = _chunk_for_byte(entry.tensor_offset)
end_chunk = _chunk_for_byte(
entry.tensor_offset + entry.aligned_size - 1
)
for ci in range(start_chunk, end_chunk + 1):
chunks_done[ci].wait()
if chunk_errors:
raise chunk_errors[0]
eoff = entry.tensor_offset
if shard_t is not None:
tensor = shard_t[eoff : eoff + entry.aligned_size]
else:
tensor = torch_module.from_numpy(
shard_arr[eoff : eoff + entry.aligned_size]
)
_put_entry(work_q, entry, tensor, cancel_event, abs_path)
if chunk_errors:
raise chunk_errors[0]
return len(sorted_entries)
except OSError as exc:
fallback_errnos = {errno.EINVAL, errno.EOPNOTSUPP}
if exc.errno not in fallback_errnos:
raise
if logger is not None:
logger.debug(
"O_DIRECT preadv failed on %s (errno %s); "
"falling back to buffered read",
abs_path,
exc.errno,
)
finally:
if io_pool is not None:
io_pool.shutdown(wait=False)
io_pool = None
if fd is not None:
os_module.close(fd)
fd = None
# Fallback: buffered full-shard read, then queue all entries.
with open(abs_path, "rb") as handle:
raw = handle.read()
arr = np_module.frombuffer(raw, dtype=np_module.uint8).copy()
for entry in sorted_entries:
off = entry.tensor_offset
tensor = torch_module.from_numpy(arr[off : off + entry.aligned_size])
_put_entry(work_q, entry, tensor, cancel_event, abs_path)
return len(sorted_entries)
def read_shard_to_queue(
abs_path: str,
sorted_entries: List[AllocationEntry],
work_q: queue.Queue[Optional[Tuple[AllocationEntry, torch.Tensor]]],
*,
pin_memory: bool,
read_shard,
cancel_event: Optional[threading.Event] = None,
) -> int:
shard_result = read_shard(
abs_path,
sorted_entries,
-1,
pin_memory=pin_memory,
)
for entry in sorted_entries:
_put_entry(
work_q, entry, shard_result[entry.allocation_id], cancel_event, abs_path
)
return len(sorted_entries)
def load_manifest_and_metadata(
input_dir: str,
) -> Tuple[SaveManifest, Dict[str, Dict[str, Any]]]:
manifest_path = os.path.join(input_dir, "manifest.json")
with open(manifest_path, encoding="utf-8") as handle:
manifest = SaveManifest.from_dict(json.load(handle))
metadata_path = os.path.join(input_dir, "gms_metadata.json")
raw_meta: Dict[str, Any] = {}
if os.path.exists(metadata_path):
with open(metadata_path, encoding="utf-8") as handle:
raw_meta = json.load(handle)
return manifest, decode_metadata(raw_meta)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List
CURRENT_VERSION = "1.0"
@dataclass(frozen=True)
class AllocationEntry:
"""Immutable record of one dumped allocation."""
allocation_id: str
size: int
aligned_size: int
tag: str
tensor_file: str
tensor_offset: int = 0
@dataclass
class SaveManifest:
"""Manifest for a GMS dump directory."""
version: str
timestamp: float
layout_hash: str
device: int
allocations: List[AllocationEntry] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
return {
"version": self.version,
"timestamp": self.timestamp,
"layout_hash": self.layout_hash,
"device": self.device,
"allocations": [asdict(a) for a in self.allocations],
}
@classmethod
def from_dict(cls, payload: Dict[str, Any]) -> "SaveManifest":
version = payload["version"]
if version != CURRENT_VERSION:
raise ValueError(
f"Unsupported manifest version {version!r} "
f"(expected {CURRENT_VERSION!r})"
)
allocations = [
AllocationEntry(
allocation_id=entry["allocation_id"],
size=entry["size"],
aligned_size=entry["aligned_size"],
tag=entry["tag"],
tensor_file=entry["tensor_file"],
tensor_offset=entry.get("tensor_offset", 0),
)
for entry in payload.get("allocations", [])
]
return cls(
version=payload["version"],
timestamp=payload["timestamp"],
layout_hash=payload["layout_hash"],
device=payload["device"],
allocations=allocations,
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import queue
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
if TYPE_CHECKING:
import torch
from gpu_memory_service.snapshot.model import AllocationEntry
WORK_QUEUE_DEPTH_MULTIPLIER = 4
@dataclass
class RestorePipelineContext:
"""Mutable state shared across disk, copy, and Phase A restore stages."""
worker_count: int
use_streams: bool
device: int
work_q: queue.Queue[Optional[Tuple[AllocationEntry, torch.Tensor]]]
va_events: Dict[str, threading.Event]
streams: List[torch.cuda.Stream]
cancel_event: threading.Event = field(default_factory=threading.Event)
vas: Dict[str, int] = field(default_factory=dict)
staged_srcs: List[torch.Tensor] = field(default_factory=list)
copy_errors: List[BaseException] = field(default_factory=list)
lock: threading.Lock = field(default_factory=threading.Lock)
@classmethod
def build(
cls,
allocations: List[AllocationEntry],
worker_count: int,
*,
device: int,
use_streams: bool,
torch_module,
) -> "RestorePipelineContext":
streams = (
[torch_module.cuda.Stream(device=device) for _ in range(worker_count)]
if use_streams
else []
)
return cls(
worker_count=worker_count,
use_streams=use_streams,
device=device,
work_q=queue.Queue(maxsize=worker_count * WORK_QUEUE_DEPTH_MULTIPLIER),
va_events={entry.allocation_id: threading.Event() for entry in allocations},
streams=streams,
)
@dataclass
class RestorePipelineResources:
"""Live restore pipeline resources that must be torn down together."""
ctx: RestorePipelineContext
disk_pool: ThreadPoolExecutor
disk_futures: Dict[Future[int], str]
copy_threads: List[threading.Thread]
active: bool = True
This diff is collapsed.
...@@ -245,9 +245,6 @@ def running_gms(monkeypatch, tmp_path): ...@@ -245,9 +245,6 @@ def running_gms(monkeypatch, tmp_path):
server_allocations, "cumem_export_to_shareable_handle", export_fd server_allocations, "cumem_export_to_shareable_handle", export_fd
) )
monkeypatch.setattr(
client_memory_manager, "cuda_set_current_device", lambda device: None
)
monkeypatch.setattr( monkeypatch.setattr(
client_memory_manager, client_memory_manager,
"cumem_get_allocation_granularity", "cumem_get_allocation_granularity",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment