"""Code in this file may safely assume the CUDA platform (e.g. it may import
pynvml), but it must not initialize a CUDA context.
"""

import os
from functools import lru_cache, wraps
from typing import Tuple

import pynvml

from .interface import Platform, PlatformEnum


def with_nvml_context(fn):
    """Wrap ``fn`` so that NVML is initialized around every call.

    ``pynvml.nvmlInit()`` runs before ``fn`` is invoked, and
    ``pynvml.nvmlShutdown()`` runs after it returns or raises. If
    ``nvmlInit()`` itself raises, that exception propagates and no shutdown
    is attempted.
    """

    @wraps(fn)
    def _nvml_scoped(*args, **kwargs):
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Pair every successful init with a shutdown, even on error.
            pynvml.nvmlShutdown()
        return result

    return _nvml_scoped


@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return the CUDA compute capability of the device at NVML index ``device_id``.

    The value is whatever ``pynvml.nvmlDeviceGetCudaComputeCapability``
    reports for the handle at that index. ``device_id`` is an NVML (physical)
    index. Logical CUDA ids should be translated first, e.g. with
    ``device_id_to_physical_device_id``. Results for up to 8 distinct ids are
    memoized, so NVML is only opened on a cache miss.
    """
    return pynvml.nvmlDeviceGetCudaComputeCapability(
        pynvml.nvmlDeviceGetHandleByIndex(device_id))


def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical CUDA device index into a physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, logical index ``device_id`` selects
    the ``device_id``-th comma-separated entry of that variable, and the entry
    is returned as an int. When the variable is unset, ``device_id`` is
    returned unchanged.

    Args:
        device_id: Logical device index as seen by CUDA in this process.

    Returns:
        The physical device index.

    Raises:
        ValueError: If ``CUDA_VISIBLE_DEVICES`` is set but empty (no devices
            are visible), or if the selected entry is not an integer (e.g. a
            GPU UUID, which this helper does not resolve).
        IndexError: If ``device_id`` is past the end of the visible list.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return device_id

    if not visible.strip():
        # An empty value hides every GPU; previously this surfaced as a
        # confusing "invalid literal for int()" error.
        raise ValueError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, so no CUDA "
            "devices are visible to this process.")

    entries = visible.split(",")
    try:
        entry = entries[device_id]
    except IndexError:
        raise IndexError(
            f"device_id {device_id} is out of range: CUDA_VISIBLE_DEVICES="
            f"{visible!r} exposes {len(entries)} device(s).") from None

    # Convert only the selected entry, so an unparsable entry elsewhere in
    # the list does not break lookups of valid indices.
    try:
        return int(entry)
    except ValueError as err:
        raise ValueError(
            f"CUDA_VISIBLE_DEVICES entry {entry.strip()!r} for device_id "
            f"{device_id} is not an integer device index.") from err


class CudaPlatform(Platform):
    """Platform implementation for NVIDIA CUDA devices.

    Device queries go through the NVML-based helpers in this module rather
    than the CUDA runtime, in keeping with this module's rule of not
    initializing a CUDA context.
    """

    _enum = PlatformEnum.CUDA

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Return the compute capability of logical CUDA device ``device_id``.

        ``device_id`` is interpreted relative to ``CUDA_VISIBLE_DEVICES`` and
        translated to a physical index before querying NVML.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        return get_physical_device_capability(physical_device_id)