# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""GMS checkpoint saver entry point.

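Waits for committed GMS weights on each device, then saves each device's GPU
memory state to its own subdirectory of the checkpoint directory. Runs as an
init sidecar — sleeps after saving until the pod terminates.

Configuration comes from environment variables: ``GMS_CHECKPOINT_DIR``
(required) and ``GMS_SAVE_WORKERS`` (optional, default 8).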
"""

from __future__ import annotations

import logging
import os
15
16
import signal
import sys
17
18
19
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

20
21
from gpu_memory_service.common.cuda_utils import list_devices
from gpu_memory_service.common.utils import get_socket_path, wait_for_weights_socket
22
23
24
25
26
27
28
29
30
from gpu_memory_service.snapshot.storage_client import GMSStorageClient

# Configure the root logger at import time: INFO level and above, with a
# timestamp, logger name and level prefixed to every record.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


def _save_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
    """Save one device's GMS state into ``<checkpoint_dir>/device-<device>``.

    First calls ``wait_for_weights_socket`` for the device, then hands the
    per-device output directory and the device's GMS socket path to a
    ``GMSStorageClient`` and runs its ``save``. Start and completion (with
    wall-clock duration) are logged.

    Args:
        checkpoint_dir: Root checkpoint directory; a ``device-<N>``
            subdirectory of it receives this device's data.
        device: Device index, as produced by ``list_devices()``.
        max_workers: Forwarded unchanged to ``GMSStorageClient.save``.
    """
    wait_for_weights_socket(device)

    target = os.path.join(checkpoint_dir, f"device-{device}")
    logger.info("Saving GMS checkpoint: device=%d output_dir=%s", device, target)

    started = time.monotonic()
    client = GMSStorageClient(
        target,
        socket_path=get_socket_path(device),
        device=device,
    )
    client.save(max_workers=max_workers)
    logger.info(
        "GMS checkpoint saved: device=%d elapsed=%.2fs",
        device,
        time.monotonic() - started,
    )


def main() -> None:
    """Save every listed device's GMS state, then idle until the pod terminates.

    Environment:
        GMS_CHECKPOINT_DIR: Required. Root directory; each device is written
            to ``<GMS_CHECKPOINT_DIR>/device-<N>``.
        GMS_SAVE_WORKERS: Optional (default ``8``). Worker count forwarded to
            each device's ``GMSStorageClient.save`` call.

    Raises:
        KeyError: If ``GMS_CHECKPOINT_DIR`` is not set.
        ValueError: If ``GMS_SAVE_WORKERS`` is not an integer.
        RuntimeError: If no devices are found, or if saving any device fails.
    """
    checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
    max_workers = int(os.environ.get("GMS_SAVE_WORKERS", "8"))

    devices = list_devices()
    if not devices:
        # ThreadPoolExecutor(max_workers=0) would raise an opaque ValueError;
        # fail with a message that says what is actually wrong.
        raise RuntimeError("No GPU devices found; nothing to checkpoint")

    logger.info("Starting GMS save for %d devices", len(devices))
    t0 = time.monotonic()
    failed: list[int] = []
    # One thread per device so all devices are saved concurrently.
    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = {
            pool.submit(_save_device, checkpoint_dir, dev, max_workers): dev
            for dev in devices
        }
        for future in as_completed(futures):
            dev = futures[future]
            try:
                future.result()
            except Exception:
                # A bare re-raise would not say which device failed; log it
                # with the device id and keep draining so every failure is
                # reported before giving up.
                logger.exception("GMS save failed: device=%d", dev)
                failed.append(dev)
    if failed:
        raise RuntimeError(f"GMS save failed for device(s): {sorted(failed)}")
    elapsed = time.monotonic() - t0
    logger.info("All %d devices saved in %.2fs", len(devices), elapsed)

    logger.info("Save complete; sleeping until pod terminates")
    # Exit with status 0 on SIGTERM instead of being killed by the signal.
    signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
    while True:
        time.sleep(3600)


# Run the saver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()