"tests/vscode:/vscode.git/clone" did not exist on "f55a9aea4573ed5282ee3221182993495ddc3709"
__init__.py 439 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from typing import Optional

import torch

from .interface import Platform, PlatformEnum

# The platform object for the accelerator backend this torch build targets,
# or None when neither backend check below matches.
current_platform: Optional[Platform] = None

# Backend detection is driven purely by the installed torch build's version
# metadata. The backend-specific modules are imported lazily, inside each
# branch, so only the module for the detected backend is ever loaded.
if torch.version.cuda is not None:
    # torch reports a CUDA version -> use the CUDA platform.
    from .cuda import CudaPlatform
    current_platform = CudaPlatform()
elif torch.version.hip is not None:
    # No CUDA version, but torch reports a HIP version -> use the ROCm
    # platform.
    from .rocm import RocmPlatform
    current_platform = RocmPlatform()
# Otherwise current_platform keeps its None default (presumably a build
# without a supported GPU backend -- confirm how callers handle None).

__all__ = ['Platform', 'PlatformEnum', 'current_platform']