Unverified commit a3c9435d, authored by youkaichao and committed by GitHub
Browse files

[hardware][cuda] use device id under CUDA_VISIBLE_DEVICES for get_device_capability (#6216)

parent 4f0e0ea1
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
pynvml. However, it should not initialize cuda context. pynvml. However, it should not initialize cuda context.
""" """
import os
from functools import lru_cache, wraps from functools import lru_cache, wraps
from typing import Tuple from typing import Tuple
...@@ -23,12 +24,27 @@ def with_nvml_context(fn): ...@@ -23,12 +24,27 @@ def with_nvml_context(fn):
return wrapper return wrapper
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return the CUDA compute capability pynvml reports for a device.

    ``device_id`` is handed directly to ``pynvml.nvmlDeviceGetHandleByIndex``;
    no ``CUDA_VISIBLE_DEVICES`` translation happens in this function.
    The call runs inside ``with_nvml_context`` and results are memoized per
    ``device_id`` by ``lru_cache`` (up to 8 entries).
    """
    nvml_device = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    return pynvml.nvmlDeviceGetCudaComputeCapability(nvml_device)
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes into its
    comma-separated list and the selected entry is returned as an integer.
    When the variable is not set, ``device_id`` is returned unchanged.

    Args:
        device_id: Index of the device among the visible devices.

    Returns:
        The integer found at position ``device_id`` of
        ``CUDA_VISIBLE_DEVICES``, or ``device_id`` itself if the variable
        is unset.

    Raises:
        ValueError: If ``CUDA_VISIBLE_DEVICES`` is set but empty, or the
            selected entry is not an integer (for example a GPU UUID).
        IndexError: If ``device_id`` is out of range for the list.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    if not visible_devices.strip():
        # Fail with an explicit message instead of the opaque
        # "invalid literal for int()" that int("") would produce.
        raise ValueError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, so no device "
            f"is visible; cannot map device_id {device_id}.")

    # Parse only the selected entry so that an unrelated non-integer entry
    # elsewhere in the list does not break the lookup of a valid index.
    entry = visible_devices.split(",")[device_id]
    try:
        return int(entry)
    except ValueError as err:
        raise ValueError(
            f"CUDA_VISIBLE_DEVICES entry {entry!r} at index {device_id} "
            "is not an integer device id.") from err
class CudaPlatform(Platform):
    """Platform implementation whose ``_enum`` is ``PlatformEnum.CUDA``."""

    _enum = PlatformEnum.CUDA

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Return the compute capability for a logical device index.

        ``device_id`` is first mapped through
        ``device_id_to_physical_device_id`` and the result is passed to
        ``get_physical_device_capability``, which performs the query.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        return get_physical_device_capability(physical_device_id)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment