rocm.py 7.68 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from functools import lru_cache
5
from typing import TYPE_CHECKING, Dict, List, Optional
6
7
8

import torch

9
import vllm.envs as envs
10
11
from vllm.logger import init_logger

12
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
13

14
15
16
17
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    # ``VllmConfig`` is only needed for annotations; bind a runtime
    # placeholder so annotations in this module still evaluate without
    # importing vllm.config at runtime.
    VllmConfig = None

logger = init_logger(__name__)

# AMD SMI bindings used for GPU topology queries (XGMI link detection).
from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
                    amdsmi_get_processor_handles, amdsmi_init,
                    amdsmi_shut_down, amdsmi_topo_get_link_type)

# Imported only for its side effects (hence ``noqa: F401``). A missing or
# broken extension is logged rather than raised so this module still loads.
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# NOTE: importing ``vllm._rocm_C`` (ROCm-specific custom ops, whose
# registration is triggered by the import) is disabled in this build.

# ``fork`` is not supported by ROCm for worker processes: default an unset
# VLLM_WORKER_MULTIPROC_METHOD to ``spawn`` and override an explicit ``fork``.
_multiproc_method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None)
if _multiproc_method in ["fork", None]:
    if _multiproc_method == "fork":
        # Only an explicit user request deserves a warning; the unset case is
        # just the default and would otherwise warn on every run.
        logger.warning("`fork` method is not supported by ROCm. "
                       "VLLM_WORKER_MULTIPROC_METHOD is overridden to"
                       " `spawn` instead.")
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

42
43
44
45
46
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm: architecture name -> reason, which
# verify_model_arch() logs as a warning.
# NOTE: entries for Qwen2ForCausalLM, MistralForCausalLM and
# MixtralForCausalLM (sliding-window attention in Triton flash attention) and
# for Phi3VForCausalLM (Triton flash-attention shared-memory compilation
# errors) are disabled in this build.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
}

67
68
69

class RocmPlatform(Platform):
    """Platform hooks for AMD GPUs running PyTorch's ROCm build.

    ROCm GPUs are addressed through PyTorch's ``cuda`` device type, so the
    device queries below go through ``torch.cuda``.
    """
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    # PyTorch exposes ROCm GPUs under the CUDA device type / dispatch key.
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    # Quantization methods this platform advertises as supported.
    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
        "fbgemm_fp8", "gguf", "quark", "moe_wna16", "blockwise_int8"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the dotted import path of the attention backend class.

        MLA models always get the Triton MLA backend. Every other request
        resolves to ``ROCmFlashAttentionBackend``; an unsupported selection
        is only logged, never rejected. ``head_size``, ``dtype``,
        ``kv_cache_dtype``, ``block_size`` and ``use_v1`` are accepted for
        interface compatibility but not consulted here.
        """
        if use_mla:
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"

        # FLASH_ATTN is treated as an alias for ROCm flash attention.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        elif selected_backend is not None:
            # Only report an explicit, unsupported choice; ``None`` means no
            # backend was requested and would log a misleading message.
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return ``torch.cuda.get_device_capability`` for ``device_id``
        (cached per device)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name PyTorch reports for ``device_id`` (cached)."""
        return torch.cuda.get_device_name(device_id)

    @staticmethod
    def is_fully_connected_nvlink_or_xgmi(
            physical_device_ids: List[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        ``physical_device_ids`` index into the list returned by
        ``amdsmi_get_processor_handles()``. Returns False as soon as any pair
        is not a 1-hop XGMI link, or if a link query raises AmdSmiException.
        """
        # The AMD SMI library must be initialised before it is queried; pair
        # init with shutdown so it is released even on early return/raise.
        amdsmi_init()
        try:
            # Fetch the handle list once instead of once per device id.
            all_handles = amdsmi_get_processor_handles()
            handles = [all_handles[i] for i in physical_device_ids]
            for i, handle in enumerate(handles):
                # Each unordered pair is checked once.
                for peer_handle in handles[i + 1:]:
                    try:
                        link_type = amdsmi_topo_get_link_type(
                            handle, peer_handle)
                        # type is 2 for XGMI
                        if link_type["hops"] != 1 or link_type["type"] != 2:
                            return False
                    except AmdSmiException as error:
                        logger.error("AMD 1 hop XGMI detection failed.",
                                     exc_info=error)
                        return False
            return True
        finally:
            amdsmi_shut_down()

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device's torch properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is available unless eager mode is forced."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in ROCm defaults on ``vllm_config`` in place.

        Sets ``cache_config.block_size`` to 16 when unset. When
        ``parallel_config.worker_cls`` is ``"auto"``, picks the multi-step,
        speculative-decoding, or plain worker class.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                # The spec-decode wrapper drives this underlying worker.
                parallel_config.sd_worker_cls = \
                    "vllm.worker.worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class check, then keep Triton AWQ off for AWQ."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, disabling VLLM_USE_TRITON_AWQ.")
            # NOTE(review): the flag is already falsy here, so this only
            # normalises it to False.
            envs.VLLM_USE_TRITON_AWQ = False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the LoRA Punica wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the bytes currently allocated by torch on ``device``.

        The peak counter is reset first, so ``max_memory_allocated`` then
        reflects the current allocation rather than a historical peak.
        """
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)