# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""GMS checkpoint loader entry point.

Waits for the GMS server UDS socket on each device, then loads saved GMS
state from a checkpoint directory into the running GMS servers. Devices
are loaded in parallel to saturate PVC bandwidth.
"""

from __future__ import annotations

import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

from gpu_memory_service.common.cuda_utils import list_devices
from gpu_memory_service.common.utils import get_socket_path, wait_for_weights_socket
from gpu_memory_service.snapshot.storage_client import GMSStorageClient

# Root logging setup for this entry point: INFO and above, each record
# stamped with time, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
    """Load one device's saved GMS state from *checkpoint_dir*.

    Waits for the device's socket via ``wait_for_weights_socket``, then
    loads ``<checkpoint_dir>/device-<device>`` through a
    ``GMSStorageClient`` bound to that device's socket path. The start and
    the end of the load are logged, the latter with the elapsed time.

    Args:
        checkpoint_dir: Directory holding the per-device ``device-<N>``
            subdirectories.
        device: Index of the device to load.
        max_workers: Forwarded unchanged to ``GMSStorageClient.load_to_gms``.
    """
    wait_for_weights_socket(device)

    device_dir = os.path.join(checkpoint_dir, f"device-{device}")
    logger.info("Loading GMS checkpoint: device=%d input_dir=%s", device, device_dir)

    # Time only the client construction and the load itself, not the socket wait.
    started = time.monotonic()
    storage = GMSStorageClient(socket_path=get_socket_path(device), device=device)
    storage.load_to_gms(device_dir, max_workers=max_workers, clear_existing=True)
    duration = time.monotonic() - started

    logger.info("GMS checkpoint loaded: device=%d elapsed=%.2fs", device, duration)


def main() -> None:
    """Load every device's GMS checkpoint in parallel, then idle until SIGTERM.

    One thread per device runs ``_load_device``. After all loads finish, a
    SIGTERM handler that exits with status 0 is installed and the process
    sleeps indefinitely.

    Environment:
        GMS_CHECKPOINT_DIR: Required. Root directory containing the
            per-device checkpoint subdirectories.
        GMS_LOAD_WORKERS: Optional, defaults to ``8``. Forwarded as
            ``max_workers`` to each per-device load.

    Raises:
        KeyError: If ``GMS_CHECKPOINT_DIR`` is not set.
        ValueError: If ``GMS_LOAD_WORKERS`` is not an integer, or if
            ``list_devices()`` returns no devices.
        Exception: Any exception raised by a per-device load is logged
            with its device id and re-raised.
    """
    checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
    max_workers = int(os.environ.get("GMS_LOAD_WORKERS", "8"))
    devices = list_devices()

    # ThreadPoolExecutor rejects max_workers=0 with an opaque message; fail
    # fast with an actionable one instead. len() is used (rather than
    # truthiness) because it is the only thing the pool sizing relies on.
    num_devices = len(devices)
    if num_devices == 0:
        raise ValueError(
            "list_devices() returned no devices; nothing to load into GMS"
        )

    t0 = time.monotonic()
    with ThreadPoolExecutor(max_workers=num_devices) as pool:
        futures = {
            pool.submit(_load_device, checkpoint_dir, dev, max_workers): dev
            for dev in devices
        }
        for future in as_completed(futures):
            dev = futures[future]
            try:
                future.result()
            except Exception:
                # Record which device failed before propagating; the bare
                # traceback alone does not identify it.
                logger.exception("Device %d load failed", dev)
                raise
            logger.info("Device %d load complete", dev)
    elapsed = time.monotonic() - t0
    logger.info("All %d devices loaded in %.2fs", num_devices, elapsed)

    # Installed only after loading completes: from here on SIGTERM becomes
    # a clean exit with status 0.
    signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))

    # Idle forever; presumably keeps the process alive once loading is
    # done -- confirm against the deployment that runs this entry point.
    while True:
        time.sleep(3600)


# Run the loader only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()