"""ROCm (AMD GPU) platform definition for vLLM."""
import os
from functools import lru_cache

import torch

from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

logger = init_logger(__name__)

# Per the warning below, the `fork` start method is not supported on ROCm.
# An unset variable or an explicit "fork" is therefore forced to "spawn".
# This happens as soon as this module is imported, so the override is in
# place before any worker processes are launched.
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") in (None, "fork"):
    logger.warning("`fork` method is not supported by ROCm. "
                   "VLLM_WORKER_MULTIPROC_METHOD is overridden to"
                   " `spawn` instead.")
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


class RocmPlatform(Platform):
    """Platform implementation for AMD GPUs (``PlatformEnum.ROCM``).

    Device queries go through the ``torch.cuda`` API. Presumably this is
    backed by HIP on ROCm builds of PyTorch — confirm for the target build.
    """

    _enum = PlatformEnum.ROCM

    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        """Pick the attention backend to use on ROCm.

        The answer is always ``_Backend.ROCM_FLASH``. ``selected_backend``
        only determines which informational message, if any, is logged.
        A request for ``_Backend.FLASH_ATTN`` is treated as a request for
        ``_Backend.ROCM_FLASH``.

        Args:
            selected_backend: The backend the caller asked for.

        Returns:
            ``_Backend.ROCM_FLASH``.
        """
        if selected_backend == _Backend.FLASH_ATTN:
            # Generic flash-attention requests map onto the ROCm variant.
            selected_backend = _Backend.ROCM_FLASH

        wants_rocm_flash = selected_backend == _Backend.ROCM_FLASH
        if not wants_rocm_flash:
            # Any other backend is ignored in favour of ROCM_FLASH.
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        elif not cls.has_device_capability(90):
            # not Instinct series GPUs.
            # NOTE(review): has_device_capability lives on the base Platform.
            # It is presumably a ">= 9.0" check, and failing it is taken to
            # mean a NAVI part. Confirm against the capabilities ROCm reports.
            logger.info("flash_attn is not supported on NAVI GPUs.")
        return _Backend.ROCM_FLASH

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the (major, minor) capability of device ``device_id``.

        Results are memoised by ``lru_cache`` (at most 8 entries).
        """
        cap_major, cap_minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=cap_major, minor=cap_minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name ``torch.cuda`` reports for ``device_id``.

        Results are memoised by ``lru_cache`` (at most 8 entries).
        """
        device_name = torch.cuda.get_device_name(device_id)
        return device_name

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total_memory`` field of ``device_id``'s properties.

        Unlike the name and capability queries, this value is not cached.
        """
        return torch.cuda.get_device_properties(device_id).total_memory