"""Detect the hardware platform at import time and expose it as ``current_platform``.

Three probes run when this package is imported (TPU via torch_xla, CUDA via
NVML, ROCm via AMD SMI). Their results are kept in the module-level flags
``is_tpu``, ``is_cuda`` and ``is_rocm``. ``current_platform`` is then bound to
the first matching platform object, in that priority order, or to
``UnspecifiedPlatform`` when none match.
"""
import logging

from .interface import Platform, PlatformEnum, UnspecifiedPlatform

logger = logging.getLogger(__name__)

current_platform: Platform

# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a CUDA build of PyTorch but run on TPU.


def _detect_tpu() -> bool:
    """Return True if torch_xla can obtain an XLA device of kind ``"TPU"``."""
    found = False
    try:
        import torch_xla.core.xla_model as xm
        xm.xla_device(devkind="TPU")
        found = True
    except Exception as exc:
        # Best-effort probe: a missing package or an unavailable device just
        # means "not this platform"; it must never abort the package import.
        logger.debug("TPU platform probe failed: %s", exc)
    return found


def _detect_cuda() -> bool:
    """Return True if NVML initializes and reports at least one device."""
    found = False
    try:
        import pynvml
        pynvml.nvmlInit()
        try:
            # Record the result before shutdown so that a failing
            # nvmlShutdown() does not discard a positive detection.
            found = pynvml.nvmlDeviceGetCount() > 0
        finally:
            pynvml.nvmlShutdown()
    except Exception as exc:
        # Best-effort probe; see _detect_tpu.
        logger.debug("CUDA platform probe failed: %s", exc)
    return found


def _detect_rocm() -> bool:
    """Return True if AMD SMI initializes and reports at least one processor."""
    found = False
    try:
        import amdsmi
        amdsmi.amdsmi_init()
        try:
            # Same ordering as _detect_cuda: keep the result even if the
            # shutdown call raises afterwards.
            found = len(amdsmi.amdsmi_get_processor_handles()) > 0
        finally:
            amdsmi.amdsmi_shut_down()
    except Exception as exc:
        # Best-effort probe; see _detect_tpu.
        logger.debug("ROCm platform probe failed: %s", exc)
    return found


# All three probes run unconditionally, as module-level flags, so other code
# can inspect each result independently of which platform wins below.
is_tpu = _detect_tpu()
is_cuda = _detect_cuda()
is_rocm = _detect_rocm()

if is_tpu:
    # People might install PyTorch built with CUDA but run on TPU,
    # so TPU has to be checked first.
    from .tpu import TpuPlatform
    current_platform = TpuPlatform()
elif is_cuda:
    from .cuda import CudaPlatform
    current_platform = CudaPlatform()
elif is_rocm:
    from .rocm import RocmPlatform
    current_platform = RocmPlatform()
else:
    current_platform = UnspecifiedPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform']