Unverified commit 96237ba1, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

fix(gms): resolve socket UUIDs via CUDA driver API (#6891)

parent bb2ca1a9
......@@ -5,15 +5,20 @@
import os
import tempfile
import uuid
import pynvml
from cuda.bindings import driver as cuda
from gpu_memory_service.common.cuda_vmm_utils import (
check_cuda_result,
ensure_cuda_initialized,
)
def get_socket_path(device: int) -> str:
    """Get GMS socket path for the given CUDA device.

    The socket path is based on the GPU UUID resolved through the CUDA
    driver API. ``device`` is passed straight to ``cuDeviceGet``, so any
    CUDA_VISIBLE_DEVICES remapping is handled by CUDA device enumeration.

    Args:
        device: CUDA device index.

    Returns:
        Socket path (e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc.sock").

    Raises:
        Whatever ``check_cuda_result`` raises for a failing driver call
        (exact exception type is defined in ``cuda_vmm_utils`` -- confirm
        there).
    """
    ensure_cuda_initialized()

    result, cu_device = cuda.cuDeviceGet(device)
    check_cuda_result(result, "cuDeviceGet")

    result, cu_uuid = cuda.cuDeviceGetUuid(cu_device)
    check_cuda_result(result, "cuDeviceGetUuid")

    # The driver hands back the UUID as raw bytes; render it in canonical
    # hyphenated form with the "GPU-" prefix shown in the example path above.
    # The result is bound to ``gpu_uuid`` (not ``uuid``) so the ``uuid``
    # module is not shadowed inside this function.
    gpu_uuid = f"GPU-{uuid.UUID(bytes=bytes(cu_uuid.bytes))}"
    return os.path.join(tempfile.gettempdir(), f"gms_{gpu_uuid}.sock")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment