loader.py 2.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""GMS checkpoint loader entry point.

Waits for the GMS server UDS socket on each device, then loads saved GMS
state from a checkpoint directory into the running GMS servers. Devices
are loaded in parallel to saturate the persistent volume (PVC) read
bandwidth. After every device has been loaded, the process stays alive by
sleeping indefinitely.
"""

from __future__ import annotations

import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.snapshot.storage_client import GMSStorageClient

# Configure the root logger (INFO level, timestamped format) at import time.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Tag handed to get_socket_path to locate the weights GMS server socket.
_WEIGHTS_TAG = "weights"
# Per-device ``max_workers`` used when GMS_LOAD_WORKERS is not set.
_DEFAULT_MAX_WORKERS = 8


def _list_devices() -> list[int]:
    """Return the index of every NVIDIA device that NVML reports.

    NVML is initialised only for the duration of the count query and is
    shut down again in a ``finally`` clause, even if the query raises.

    Raises:
        SystemExit: if NVML reports zero devices.
    """
    import pynvml

    pynvml.nvmlInit()
    try:
        num_devices = pynvml.nvmlDeviceGetCount()
    finally:
        pynvml.nvmlShutdown()

    if not num_devices:
        raise SystemExit("no nvidia devices found")
    return [*range(num_devices)]


def _wait_for_weights_socket(
    device: int,
    timeout: float | None = None,
    poll_interval: float = 1.0,
) -> None:
    """Block until the weights GMS server socket file exists for ``device``.

    Args:
        device: Device index used to derive the socket path.
        timeout: Maximum number of seconds to wait. ``None`` (the default)
            waits indefinitely, as the loader always has.
        poll_interval: Seconds to sleep between existence checks.

    Raises:
        TimeoutError: if ``timeout`` is given and elapses before the socket
            file appears.
    """
    socket_path = get_socket_path(device, _WEIGHTS_TAG)
    if os.path.exists(socket_path):
        return

    # Log once, so a server that never comes up is visible in the logs
    # instead of looking like a silent hang.
    logger.info(
        "Waiting for GMS weights socket: device=%d path=%s", device, socket_path
    )
    deadline = None if timeout is None else time.monotonic() + timeout
    while not os.path.exists(socket_path):
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"GMS weights socket for device {device} did not appear at "
                f"{socket_path} within {timeout}s"
            )
        time.sleep(poll_interval)


def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
    """Load one device's saved GMS state into its running GMS server.

    Blocks until the weights server socket for ``device`` exists. It then
    calls ``GMSStorageClient.load_to_gms`` on ``<checkpoint_dir>/device-<device>``
    with ``clear_existing=True``.

    Args:
        checkpoint_dir: Root checkpoint directory holding ``device-<N>``
            subdirectories.
        device: Device index to load.
        max_workers: Forwarded to ``load_to_gms``. Presumably this sets that
            call's internal parallelism; confirm against GMSStorageClient.
    """
    _wait_for_weights_socket(device)
    # Connect to the exact socket we just waited on. A bare
    # get_socket_path(device) may resolve to a different tag, whose server
    # might not be up yet.
    socket_path = get_socket_path(device, _WEIGHTS_TAG)
    input_dir = os.path.join(checkpoint_dir, f"device-{device}")
    logger.info("Loading GMS checkpoint: device=%d input_dir=%s", device, input_dir)

    started = time.monotonic()
    client = GMSStorageClient(socket_path=socket_path, device=device)
    client.load_to_gms(input_dir, max_workers=max_workers, clear_existing=True)
    logger.info(
        "GMS checkpoint loaded: device=%d elapsed=%.2fs",
        device,
        time.monotonic() - started,
    )


def _positive_int_env(name: str, default: int) -> int:
    """Read a positive integer from env var ``name``; use ``default`` if unset or blank."""
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw)
    except ValueError:
        raise SystemExit(f"{name} must be an integer, got {raw!r}") from None
    if value < 1:
        raise SystemExit(f"{name} must be >= 1, got {value}")
    return value


def main() -> None:
    """Load the GMS checkpoint onto every device, then idle forever.

    Configuration comes from the environment:

    * ``GMS_CHECKPOINT_DIR`` (required): directory containing one
      ``device-<N>`` subdirectory per device.
    * ``GMS_LOAD_WORKERS`` (optional, default ``_DEFAULT_MAX_WORKERS``):
      ``max_workers`` passed to each device's load.

    Raises:
        SystemExit: if the configuration is missing or invalid.
        RuntimeError: if loading fails on any device. Each failure is logged
            with its device index first, and the first failure (by device
            index) is chained as the cause.
    """
    checkpoint_dir = os.environ.get("GMS_CHECKPOINT_DIR")
    if not checkpoint_dir:
        raise SystemExit("GMS_CHECKPOINT_DIR must be set to the checkpoint directory")
    max_workers = _positive_int_env("GMS_LOAD_WORKERS", _DEFAULT_MAX_WORKERS)
    devices = _list_devices()

    t0 = time.monotonic()
    failures: dict[int, BaseException] = {}
    # One thread per device so every device loads concurrently. Leaving the
    # ``with`` block waits for all loads, so each outcome gets recorded.
    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = {
            pool.submit(_load_device, checkpoint_dir, dev, max_workers): dev
            for dev in devices
        }
        for future in as_completed(futures):
            dev = futures[future]
            exc = future.exception()
            if exc is not None:
                logger.error("Device %d load failed", dev, exc_info=exc)
                failures[dev] = exc
                continue
            logger.info("Device %d load complete", dev)
    elapsed = time.monotonic() - t0

    if failures:
        failed = sorted(failures)
        raise RuntimeError(
            f"GMS checkpoint load failed on device(s) {failed}"
        ) from failures[failed[0]]

    logger.info("All %d devices loaded in %.2fs", len(devices), elapsed)

    # Keep the process running after a successful load. Presumably the
    # surrounding deployment expects this entry point to stay up; TODO confirm.
    while True:
        time.sleep(3600)


if __name__ == "__main__":
    # Script entry point: main() never returns normally. It either raises or
    # sleeps forever, so wrapping it in SystemExit does not change behavior.
    raise SystemExit(main())