# __init__.py
from typing import Optional

import torch
from vllm.utils import is_tpu

from .interface import Platform, PlatformEnum, UnspecifiedPlatform
# Module-level platform object, selected once when this module is imported.
# Every branch of the chain below assigns a freshly constructed platform
# instance; no branch assigns None.
current_platform: Optional[Platform]

# The first matching condition wins: CUDA is checked before HIP, and
# is_tpu() is only consulted when torch reports neither a CUDA nor a HIP
# version. Each backend module is imported inside its own branch, so only
# the module for the selected branch is imported here.
if torch.version.cuda is not None:
    from .cuda import CudaPlatform
    current_platform = CudaPlatform()
elif torch.version.hip is not None:
    from .rocm import RocmPlatform
    current_platform = RocmPlatform()
elif is_tpu():
    from .tpu import TpuPlatform
    current_platform = TpuPlatform()
else:
    # None of the checks above matched; fall back to the placeholder
    # platform imported from .interface.
    current_platform = UnspecifiedPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform']