"docs/pages/components/router/router-examples.md" did not exist on "80e7bafd37a0bc5970bea955a63e746ba5adac5a"
saver.py 4.87 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""GMS checkpoint saver entry point.

Waits for the checkpoint pod to reach Ready=True, then saves GMS state from
each device in parallel. Writes a stop file to signal the GMS server to shut
down after save completes.
"""

from __future__ import annotations

import json
import logging
import os
import ssl
import time
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any

from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.snapshot.storage_client import GMSStorageClient

# Configure root logging when this module is loaded: INFO level, with
# timestamp, logger name and level prefixed to every record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Tag passed to get_socket_path() when waiting for a device's weights socket.
_WEIGHTS_TAG = "weights"
# Service-account token and CA bundle paths; main() uses the token as a Bearer
# credential and the CA file to build the TLS context for the pod API request.
_SERVICE_ACCOUNT_TOKEN = Path("/var/run/secrets/kubernetes.io/serviceaccount/token")
_SERVICE_ACCOUNT_CA = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"


def _list_devices() -> list[int]:
    """Return the index of every NVIDIA device reported by NVML.

    NVML is initialised only for the duration of the count query and is shut
    down again (even if the query raises) before the result is inspected.

    Returns:
        The device indices ``0 .. count - 1``.

    Raises:
        SystemExit: if NVML reports zero devices.
    """
    # Function-scope import: pynvml is only loaded when this function runs.
    import pynvml as nvml

    nvml.nvmlInit()
    try:
        device_count = nvml.nvmlDeviceGetCount()
    finally:
        nvml.nvmlShutdown()

    if device_count != 0:
        return list(range(device_count))
    raise SystemExit("no nvidia devices found")


def _wait_for_weights_socket(device: int) -> None:
    """Block until the weights socket path for *device* exists.

    The path is resolved once and then polled once per second; there is no
    upper bound on how long this waits.
    """
    weights_socket = get_socket_path(device, _WEIGHTS_TAG)
    while True:
        if os.path.exists(weights_socket):
            return
        time.sleep(1)


def _checkpoint_pod_ready(pod: dict[str, Any]) -> bool:
    status = pod.get("status") or {}
    if str(status.get("phase", "")).strip() != "Running":
        return False
    for condition in status.get("conditions") or []:
        if (
            condition.get("type") == "Ready"
            and str(condition.get("status", "")).strip() == "True"
        ):
            return True
    return False


def _main_terminated(pod: dict[str, Any]) -> bool:
    status = pod.get("status") or {}
    for container in status.get("containerStatuses") or []:
        if container.get("name") != "main":
            continue
        return bool((container.get("state") or {}).get("terminated"))
    return False


def main() -> None:
    """Wait for the checkpoint pod to be Ready, then save GMS state per device.

    Steps:
      1. Poll the Kubernetes API once per second for the pod named by
         ``POD_NAMESPACE`` / ``POD_NAME`` until it is Running with Ready=True.
      2. Save every NVIDIA device in parallel (one thread per device) into
         ``$GMS_CHECKPOINT_DIR/device-<N>``.
      3. Write ``$GMS_CONTROL_DIR/checkpoint-done`` once the save phase ends,
         whether it succeeded or raised.

    Environment:
        KUBERNETES_SERVICE_HOST: API server host (required).
        KUBERNETES_SERVICE_PORT_HTTPS: API server port (default ``"443"``).
        POD_NAMESPACE, POD_NAME: identify the pod to poll (required).
        GMS_CHECKPOINT_DIR: root directory for per-device output (required).
        GMS_SAVE_WORKERS: ``max_workers`` passed to each client save
            (default ``"8"``).
        GMS_CONTROL_DIR: directory that receives ``checkpoint-done`` (required).

    Raises:
        SystemExit: if the ``main`` container terminates before the pod is
            Ready, or if no NVIDIA devices are found.
        KeyError: if a required environment variable is missing.
        Exception: any error raised by a device save is re-raised after the
            done file has been written.
    """
    service_token = _SERVICE_ACCOUNT_TOKEN.read_text(encoding="utf-8").strip()
    ssl_context = ssl.create_default_context(cafile=_SERVICE_ACCOUNT_CA)
    pod_api_url = (
        "https://"
        + os.environ["KUBERNETES_SERVICE_HOST"]
        + ":"
        + os.environ.get("KUBERNETES_SERVICE_PORT_HTTPS", "443")
        + f"/api/v1/namespaces/{os.environ['POD_NAMESPACE']}/pods/{os.environ['POD_NAME']}"
    )
    checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]

    def checkpoint_pod() -> dict[str, Any]:
        """Fetch the pod object from the Kubernetes API as decoded JSON."""
        request = urllib.request.Request(
            pod_api_url,
            headers={"Authorization": f"Bearer {service_token}"},
        )
        with urllib.request.urlopen(
            request,
            context=ssl_context,
            timeout=5,
        ) as response:
            return json.load(response)

    logger.info("Waiting for checkpoint pod Ready=True before GMS save")
    while True:
        try:
            pod = checkpoint_pod()
        except OSError as exc:
            # URLError (and therefore HTTPError), TimeoutError and ssl.SSLError
            # are all OSError subclasses, so this catches the same failures as
            # before. Log them so a persistent problem (bad credentials, wrong
            # URL, unreachable API server) is visible instead of looking like
            # a silent hang.
            logger.warning("Pod status request failed; retrying in 1s: %s", exc)
            time.sleep(1)
            continue

        if _checkpoint_pod_ready(pod):
            break
        if _main_terminated(pod):
            raise SystemExit("main container terminated before GMS save could start")
        time.sleep(1)

    def _save_device(device: int, max_workers: int) -> None:
        """Wait for the device's weights socket, then save its GMS state."""
        _wait_for_weights_socket(device)
        output_dir = os.path.join(checkpoint_dir, f"device-{device}")
        logger.info(
            "Saving GMS checkpoint: device=%d output_dir=%s",
            device,
            output_dir,
        )
        t0 = time.monotonic()
        # NOTE(review): the wait above polls get_socket_path(device, "weights")
        # while the client is given get_socket_path(device) with no tag —
        # confirm the default tag resolves to the same socket.
        client = GMSStorageClient(
            output_dir,
            socket_path=get_socket_path(device),
            device=device,
        )
        client.save(max_workers=max_workers)
        elapsed = time.monotonic() - t0
        logger.info("GMS checkpoint saved: device=%d elapsed=%.2fs", device, elapsed)

    max_workers = int(os.environ.get("GMS_SAVE_WORKERS", "8"))
    logger.info("Checkpoint pod is Ready; starting GMS save")
    try:
        devices = _list_devices()
        t0 = time.monotonic()
        # One thread per device so all devices are saved concurrently.
        with ThreadPoolExecutor(max_workers=len(devices)) as pool:
            futures = {
                pool.submit(_save_device, dev, max_workers): dev for dev in devices
            }
            for future in as_completed(futures):
                # result() re-raises the first failure from a device save.
                future.result()
        elapsed = time.monotonic() - t0
        logger.info("All %d devices saved in %.2fs", len(devices), elapsed)
    finally:
        # Written on success and on failure alike, so the done file always
        # appears once the save phase is over.
        (Path(os.environ["GMS_CONTROL_DIR"]) / "checkpoint-done").write_text(
            "done",
            encoding="utf-8",
        )


# Script entry point: run the saver only when executed directly, not on import.
if __name__ == "__main__":
    main()