rocm.py 959 Bytes
Newer Older
1
import os
2
3
4
5
from functools import lru_cache

import torch

6
7
from vllm.logger import init_logger

8
from .interface import DeviceCapability, Platform, PlatformEnum
9

10
11
12
# Module-level logger obtained from vLLM's logging helper, named after this module.
logger = init_logger(__name__)

# The `fork` start method is not supported by ROCm, so force worker processes
# to use `spawn` whenever the method is unset or explicitly set to `fork`.
# Any other explicit choice is left untouched. The override is deliberately
# silent: the warning that used to be logged here has been disabled.
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") in ("fork", None):
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


class RocmPlatform(Platform):
    """Platform implementation for ROCm devices.

    Device queries are answered through the ``torch.cuda`` API. Results are
    memoized with ``functools.lru_cache`` (up to 8 entries per method), keyed
    on the class and the ``device_id`` argument.
    """

    _enum = PlatformEnum.ROCM

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of device ``device_id``.

        Args:
            device_id: Index of the device to query (defaults to 0).

        Returns:
            A ``DeviceCapability`` built from the ``(major, minor)`` pair
            reported by ``torch.cuda.get_device_capability``.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of device ``device_id``.

        Args:
            device_id: Index of the device to query (defaults to 0).

        Returns:
            The device name as reported by ``torch.cuda.get_device_name``.
        """
        return torch.cuda.get_device_name(device_id)