utils.py 1.8 KB
Newer Older
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Shared utilities for GPU Memory Service."""

6
import logging
7
8
import os
import tempfile
9
from typing import NoReturn
10

11
logger = logging.getLogger(__name__)
12
13


14
15
16
17
def fail(message: str, *args, exc_info=None) -> NoReturn:
    logger.critical(message, *args, exc_info=exc_info)
    logging.shutdown()
    os._exit(1)
18

19

20
21
22
23
24
25
26
27
_uuid_cache: dict[int, str] = {}


def invalidate_uuid_cache() -> None:
    """Clear cached GPU UUIDs. Call after CRIU restore when GPU assignment may change."""
    _uuid_cache.clear()


28
29
30
31
def get_socket_path(device: int, tag: str = "weights") -> str:
    """Get GMS socket path for the given CUDA device and tag.

    The socket path is based on GPU UUID, making it stable across different
32
    CUDA_VISIBLE_DEVICES configurations. UUIDs are cached per device index.
33
34
35
36
37

    Args:
        device: CUDA device index.

    Returns:
38
39
        Socket path
        (e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock").
40
    """
41
42
43
44
45
46
47
48
49
50
51
52
53
    uuid = _uuid_cache.get(device)
    if uuid is None:
        import pynvml  # deferred: not available in all environments

        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device)
            uuid = pynvml.nvmlDeviceGetUUID(handle)
        finally:
            pynvml.nvmlShutdown()
        _uuid_cache[device] = uuid
    socket_dir = os.environ.get("GMS_SOCKET_DIR") or tempfile.gettempdir()
    return os.path.join(socket_dir, f"gms_{uuid}_{tag}.sock")
54
55
56
57
58
59
60
61
62


def wait_for_weights_socket(device: int) -> None:
    """Block until the GMS weights socket for the given device exists."""
    import time

    path = get_socket_path(device, "weights")
    while not os.path.exists(path):
        time.sleep(0.1)