# rocm.py
import os
from functools import lru_cache
from typing import Tuple

import torch

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum

logger = init_logger(__name__)

# ROCm does not support the `fork` worker start method, so force `spawn`.
# An unset variable (``get`` returns None) is treated exactly like an
# explicit `fork` and is overridden too. Being module-level, this check
# runs when the module is imported.
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") in ("fork", None):
    logger.warning(
        "`fork` method is not supported by ROCm. "
        "VLLM_WORKER_MULTIPROC_METHOD is overridden to `spawn` instead.")
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

class RocmPlatform(Platform):
    """Platform implementation for AMD GPUs running under ROCm.

    Device queries go through ``torch.cuda``: PyTorch's ROCm build exposes
    HIP devices via the ``torch.cuda`` API, so no separate HIP binding is
    used here.
    """

    _enum = PlatformEnum.ROCM

    @staticmethod
    @lru_cache(maxsize=8)
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Return the capability tuple of GPU ``device_id``.

        Args:
            device_id: Index of the GPU to query. Defaults to device 0.

        Returns:
            The ``(major, minor)`` tuple reported by
            ``torch.cuda.get_device_capability``. Results are memoized per
            ``device_id`` by ``lru_cache`` (at most 8 distinct ids kept).
        """
        return torch.cuda.get_device_capability(device_id)

    @staticmethod
    @lru_cache(maxsize=8)
    def get_device_name(device_id: int = 0) -> str:
        """Return the name of GPU ``device_id``.

        Args:
            device_id: Index of the GPU to query. Defaults to device 0.

        Returns:
            The name string reported by ``torch.cuda.get_device_name``.
            Results are memoized per ``device_id`` by ``lru_cache``
            (at most 8 distinct ids kept).
        """
        return torch.cuda.get_device_name(device_id)