rocm.py 14.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from functools import cache, lru_cache, wraps
5
from typing import TYPE_CHECKING, Dict, List, Optional
6
7
8

import torch

9
import vllm.envs as envs
10
11
from vllm.logger import init_logger

12
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
13

14
if TYPE_CHECKING:
15
    from vllm.config import ModelConfig, VllmConfig
16

17
18
logger = init_logger(__name__)

19
try:
20
21
22
    from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down, amdsmi_topo_get_link_type)
23
24
25
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

26
27
28
29
30
31
32
33
34
35
36
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Architectures that ROCm cannot run at all (currently none).
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Architectures that run on ROCm with caveats.
# Architecture -> human-readable reason shown to the user.
_ROCM_SWA_REASON = (
    "Sliding window attention (SWA) is not yet supported in Triton flash "
    "attention. For half-precision SWA support, please use CK flash "
    "attention by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    # These architectures all share the sliding-window limitation.
    **dict.fromkeys(
        ("Qwen2ForCausalLM", "MistralForCausalLM", "MixtralForCausalLM"),
        _ROCM_SWA_REASON,
    ),
    "PaliGemmaForConditionalGeneration":
    "ROCm flash attention does not yet fully support 32-bit precision on "
    "PaliGemma",
    "Phi3VForCausalLM":
    "ROCm Triton flash attention may run into compilation errors due to "
    "excessive use of shared memory. If this happens, disable Triton FA by "
    "setting `VLLM_USE_TRITON_FLASH_ATTN=0`",
}
61
62
63
64
65
66
67
68
69
# Maps the `device_id` string reported by `amdsmi_get_gpu_asic_info` to a
# fixed device name. `RocmPlatform.get_device_name` consults this table first
# and falls back to the ASIC's `market_name` for ids that are not listed here.
# Several ids may map to the same name (e.g. the "VF" entries below).
_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
}
70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`.
# If only HIP_VISIBLE_DEVICES is set, mirror it into CUDA_VISIBLE_DEVICES.
# If both are set (and CUDA_VISIBLE_DEVICES is non-empty) they must agree;
# otherwise fail at import time with a message that names the clash.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        assert val == cuda_val, (
            "HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are both set but "
            f"differ ({val!r} != {cuda_val!r}); set only one of them or "
            "give them the same value.")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so that it runs between AMDSMI init and shutdown.

    ``amdsmi_init()`` is called before every invocation of ``fn`` and
    ``amdsmi_shut_down()`` is called afterwards, whether ``fn`` returns
    normally or raises.
    """

    @wraps(fn)
    def _amdsmi_scoped(*call_args, **call_kwargs):
        amdsmi_init()
        try:
            result = fn(*call_args, **call_kwargs)
        finally:
            # Always release AMDSMI, even if `fn` raised.
            amdsmi_shut_down()
        return result

    return _amdsmi_scoped


98
@cache
def on_mi250_mi300() -> bool:
    """Return whether the "cuda" device's arch name is gfx90a or gfx942.

    The check is a substring match against ``gcnArchName`` of the device
    properties. The result is cached after the first call.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx90a" in arch_name or "gfx942" in arch_name


104
105
106
107
108
109
@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                    block_size: int, gqa_ratio: int,
                                    max_seq_len: int,
                                    sliding_window: int) -> bool:
    """Decide whether the ROCm custom paged-attention kernel can be used.

    Every condition below must hold; the first one that fails yields
    ``False``. Results are cached per argument combination.

    Args:
        qtype: Query dtype; must be ``torch.half`` or ``torch.bfloat16``.
        head_size: Attention head size; must be 64 or 128.
        block_size: KV-cache block size; must be 16 or 32.
        gqa_ratio: Grouped-query ratio; must be within ``[1, 16]``.
        max_seq_len: Maximum sequence length; must not exceed 32768.
        sliding_window: Sliding-window setting; on V1 it must be ``0`` or
            ``(-1, -1)``.

    Returns:
        ``True`` if the custom kernel is usable, ``False`` otherwise.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName

    # rocm custom paged attention is not supported on gfx1*
    if not any(gfx in arch_name for gfx in ("gfx90a", "gfx942", "gfx950")):
        return False

    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    if envs.VLLM_USE_V1 and sliding_window not in (0, (-1, -1)):
        return False

    if qtype not in (torch.half, torch.bfloat16):
        return False
    if head_size not in (64, 128):
        return False
    if block_size not in (16, 32):
        return False
    if not 1 <= gqa_ratio <= 16:
        return False
    if max_seq_len > 32768:
        return False

    if not envs.VLLM_ROCM_CUSTOM_PAGED_ATTN:
        return False
    # The AITER paged-attention path takes precedence when both of its
    # switches are enabled.
    return not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
                and envs.VLLM_ROCM_USE_AITER)
125
126


127
128
class RocmPlatform(Platform):
    """Platform implementation for AMD ROCm devices.

    Device properties are queried through ``torch.cuda`` (``device_type`` is
    ``"cuda"``); device names and link topology are queried through AMDSMI.
    """
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
        "quark", "ptpc_fp8"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the dotted import path of the attention backend to use.

        Args:
            selected_backend: Requested ``_Backend`` member, or ``None`` to
                let this method pick one.
            head_size: Not used by this implementation.
            dtype: Not used by this implementation.
            kv_cache_dtype: Not used by this implementation.
            block_size: KV-cache block size. Triton MLA rejects block size 1;
                AITER MLA requires block size 1.
            use_v1: Chooses between the V1 and V0 AITER MLA backend.
            use_mla: Whether an MLA backend is requested.

        Returns:
            Fully qualified class path of the chosen backend.

        Raises:
            ValueError: If MLA is requested and the selected backend is not
                an MLA backend or does not support ``block_size``.
        """
        if use_mla:
            from vllm.attention.backends.rocm_aiter_mla import (
                is_aiter_mla_enabled)

            if selected_backend is None:
                selected_backend = (_Backend.ROCM_AITER_MLA if
                                    is_aiter_mla_enabled() or block_size == 1
                                    else _Backend.TRITON_MLA)

            if selected_backend == _Backend.TRITON_MLA:
                if block_size == 1:
                    raise ValueError(
                        f"The selected backend, {selected_backend.name}, "
                        f"does not support block size {block_size}.")
                logger.info("Using Triton MLA backend.")
                return "vllm.attention.backends.triton_mla.TritonMLABackend"  # noqa: E501

            if selected_backend in (_Backend.ROCM_AITER_MLA,
                                    _Backend.ROCM_AITER_MLA_VLLM_V1):
                if block_size != 1:
                    raise ValueError(
                        f"The selected backend, {selected_backend.name}, "
                        f"does not support block size {block_size} "
                        "(currently only supports block size 1).")
                if use_v1:
                    logger.info("Using AITER MLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                logger.info("Using AITER MLA backend")
                return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend"  # noqa: E501

            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend.")

        # A FLASH_ATTN request is served by the ROCm flash-attention backend.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)

        # NOTE(review): this path consults `envs.VLLM_USE_V1` rather than the
        # `use_v1` argument used in the MLA branch above; kept as-is.
        if envs.VLLM_USE_V1:
            logger.info("Using Triton Attention backend on V1 engine.")
            return ("vllm.v1.attention.backends."
                    "triton_attn.TritonAttentionBackend")

        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        elif selected_backend is not None:
            # Only report an unsupported backend when one was actually
            # requested; previously this logged "None is not supported".
            logger.info("%s is not supported in AMD GPUs.", selected_backend)

        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the capability reported by torch for ``device_id``.

        Results are memoized (up to 8 entries).
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @staticmethod
    @with_amdsmi_context
    def is_fully_connected(physical_device_ids: List[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Returns ``False`` as soon as one pair is not a 1-hop XGMI link, or
        if AMDSMI raises while querying a pair.
        """
        # Fetch the handle list once instead of once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]

        # Check every unordered pair (i < j) exactly once.
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1:]:
                try:
                    link_type = amdsmi_topo_get_link_type(
                        handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.",
                                 exc_info=error)
                    return False
        return True

    # `lru_cache` is applied outside `with_amdsmi_context` so that cache hits
    # return immediately without re-initialising and shutting down AMDSMI.
    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for ``device_id``.

        The ASIC ``device_id`` is looked up in ``_ROCM_DEVICE_ID_NAME_MAP``;
        if it is not listed, the ASIC ``market_name`` is returned instead.
        Results are memoized (up to 8 entries).
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``False`` (with a warning) when eager mode is enforced."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in ROCm defaults on ``vllm_config`` in place.

        Sets the cache block size to 16 when it is unset, and resolves
        ``parallel_config.worker_cls == "auto"`` to a concrete worker class.

        Raises:
            NotImplementedError: If multi-step scheduling or speculative
                decoding is configured while ``envs.VLLM_USE_V1`` is set.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls != "auto":
            # An explicit worker class was configured; leave it alone.
            return

        if scheduler_config.is_multi_step:
            if envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Multi-step scheduling is not supported (and not "
                    "needed) on vLLM V1. Please launch without "
                    "--num-scheduler-steps.")
            parallel_config.worker_cls = \
                "vllm.worker.multi_step_worker.MultiStepWorker"
        elif vllm_config.speculative_config:
            if envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Speculative decoding is not yet supported on vLLM V1."
                )
            parallel_config.worker_cls = \
                "vllm.spec_decode.spec_decode_worker.create_spec_worker"
            parallel_config.sd_worker_cls = \
                "vllm.worker.worker.Worker"
        elif envs.VLLM_USE_V1:
            parallel_config.worker_cls = \
                "vllm.v1.worker.gpu_worker.Worker"
        else:
            parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: If ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class check, then force ``VLLM_USE_TRITON_AWQ`` on."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
        # NOTE: set unconditionally (outside the `if`), regardless of `quant`;
        # the warning above only reports the case where it was not yet set.
        envs.VLLM_USE_TRITON_AWQ = True

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return used device memory in bytes (total minus free).

        Also resets torch's peak memory statistics for ``device``.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # `mem_get_info` returns (free, total); query it once and reuse both.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """Return whether device 0's arch name contains ``gfx95``."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether device 0's arch name is gfx94*, gfx95* or gfx12*."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return whether device 0's arch name contains ``gfx94``."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return ``float8_e4m3fnuz`` on fnuz devices, else ``float8_e4m3fn``."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        return torch.float8_e4m3fn

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        """Always return ``True``; ``model_config`` is not inspected."""
        # V1 support on AMD gpus is experimental
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return whether device 0's arch name contains gfx94 or gfx95."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94', 'gfx95']
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Return ``multi_processor_count`` from torch's device properties."""
        return torch.cuda.get_device_properties(
            device_id).multi_processor_count