"""Select the platform implementation exposed as ``current_platform``.

The choice is made once, at import time, in this order:

1. TPU, if the ``libtpu`` module can be imported;
2. CUDA, if ``torch.version.cuda`` is not ``None``;
3. ROCm, if ``torch.version.hip`` is not ``None``;
4. otherwise an ``UnspecifiedPlatform`` instance.

Each backend module is imported only inside the branch that selects it,
so the backends that are not chosen are never imported here.
"""

import torch

from .interface import Platform, PlatformEnum, UnspecifiedPlatform

current_platform: Platform

# Probe for TPU support by attempting the import; a missing package is the
# expected case on non-TPU hosts, so only ImportError is handled.
try:
    import libtpu
except ImportError:
    libtpu = None

if libtpu is not None:
    # people might install pytorch built with cuda but run on tpu
    # so we need to check tpu first
    from .tpu import TpuPlatform
    current_platform = TpuPlatform()
elif torch.version.cuda is not None:
    from .cuda import CudaPlatform
    current_platform = CudaPlatform()
elif torch.version.hip is not None:
    from .rocm import RocmPlatform
    current_platform = RocmPlatform()
else:
    # Neither TPU, CUDA nor ROCm was detected.
    current_platform = UnspecifiedPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform']