# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Shared utilities for GPU Memory Service."""

import logging
import os
import tempfile
from typing import NoReturn

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def fail(message: str, *args, exc_info=None) -> NoReturn:
    """Log a critical message and terminate the process with exit status 1.

    Args:
        message: %-style format string for the log record.
        *args: Values the logging module interpolates into ``message``.
        exc_info: Forwarded unchanged to ``logger.critical``; pass ``True``,
            an exception instance, or an exc_info tuple to attach a traceback.

    This function never returns. It exits through ``os._exit``, which skips
    atexit handlers and does not flush stdio buffers. Output written directly
    to ``sys.stdout`` (not through logging) may therefore be lost.
    """
    exit_status = 1
    logger.critical(message, *args, exc_info=exc_info)
    # os._exit() bypasses normal interpreter teardown, so logging would never
    # flush or close its handlers on its own. Shut it down explicitly so the
    # critical record is actually emitted before the process dies.
    logging.shutdown()
    os._exit(exit_status)

def get_socket_path(device: int, tag: str = "weights") -> str:
    """Get GMS socket path for the given CUDA device and tag.

    The socket path is keyed on the GPU UUID rather than on a device index,
    so a given physical GPU maps to the same path even when
    CUDA_VISIBLE_DEVICES renumbers the visible devices.

    Args:
        device: GPU index, resolved to a device handle via
            ``pynvml.nvmlDeviceGetHandleByIndex`` (see the review note below).
        tag: Label embedded in the socket file name (default ``"weights"``).

    Returns:
        Socket path
        (e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock").

    Raises:
        Any exception raised by pynvml (e.g. ``pynvml.NVMLError``) while
        initializing NVML or querying the device propagates unchanged.
    """
    # NOTE(review): NVML enumerates all GPUs in the system and its indices
    # ignore CUDA_VISIBLE_DEVICES, so they need not match the CUDA runtime's
    # device index. Confirm that callers pass an NVML index, or translate the
    # CUDA index before calling.
    import pynvml

    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
        uuid = pynvml.nvmlDeviceGetUUID(handle)
    finally:
        # Always balance nvmlInit(), even if the device lookup fails.
        pynvml.nvmlShutdown()
    # Older pynvml releases return bytes here. Without decoding, the f-string
    # below would embed the repr "b'GPU-...'" in the socket path.
    if isinstance(uuid, bytes):
        uuid = uuid.decode()
    return os.path.join(tempfile.gettempdir(), f"gms_{uuid}_{tag}.sock")