rocm.py 17.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.logger import init_logger
12
from vllm.utils import cuda_device_count_stateless
13

14
from .interface import DeviceCapability, Platform, PlatformEnum
15

16
if TYPE_CHECKING:
17
    from vllm.attention.backends.registry import _Backend
18
    from vllm.config import ModelConfig, VllmConfig
19
20
else:
    _Backend = None
21

22
23
logger = init_logger(__name__)

24
try:
25
26
27
28
29
30
31
32
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
33
34
35
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

36
37
38
39
40
41
42
43
44
45
46
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

47
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# The sliding-window reason is shared by several architectures, so it is
# defined once and referenced from the table below.
_ROCM_SWA_REASON = (
    "Sliding window attention (SWA) is not yet supported in "
    "Triton flash attention. For half-precision SWA support, "
    "please use CK flash attention by setting "
    "`VLLM_USE_TRITON_FLASH_ATTN=0`"
)
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
    "MistralForCausalLM": _ROCM_SWA_REASON,
    "MixtralForCausalLM": _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration": (
        "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma"
    ),
    "Phi3VForCausalLM": (
        "ROCm Triton flash attention may run into compilation errors due to "
        "excessive use of shared memory. If this happens, disable Triton FA "
        "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
    ),
}
71
# Device ID (hex string, as reported in AMD SMI asic info) -> canonical
# device name. Several IDs can map to the same name.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
}
80

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`.
# When HIP_VISIBLE_DEVICES is set, mirror it into CUDA_VISIBLE_DEVICES, and
# refuse to continue if both are set to different values.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Raise explicitly instead of using `assert` so the check is not
        # stripped under `python -O`. AssertionError is kept so the failure
        # type is unchanged.
        if val != cuda_val:
            raise AssertionError(
                "HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are both set "
                f"but differ: {val!r} != {cuda_val!r}"
            )
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so it runs between ``amdsmi_init`` and ``amdsmi_shut_down``.

    ``amdsmi_shut_down`` is called even when ``fn`` raises; the exception
    then propagates to the caller.
    """

    @wraps(fn)
    def _run_with_amdsmi(*call_args, **call_kwargs):
        amdsmi_init()
        try:
            outcome = fn(*call_args, **call_kwargs)
        finally:
            amdsmi_shut_down()
        return outcome

    return _run_with_amdsmi


107
108
109
110
111
112
@cache
def on_gfx1x() -> bool:
    """Report whether the "cuda" device's ``gcnArchName`` contains gfx11 or gfx12.

    The result is memoised, so the device is only queried on the first call.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx11" in arch_name or "gfx12" in arch_name


113
@cache
def on_mi3xx() -> bool:
    """Report whether the "cuda" device's ``gcnArchName`` contains gfx942 or gfx950.

    The result is memoised, so the device is only queried on the first call.
    """
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])


@cache
def on_gfx9() -> bool:
    """Report whether the "cuda" device's ``gcnArchName`` contains gfx90a, gfx942 or gfx950.

    The result is memoised, so the device is only queried on the first call.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for gfx9_tag in ("gfx90a", "gfx942", "gfx950"):
        if gfx9_tag in arch_name:
            return True
    return False
123
124


125
126
127
128
129
130
@cache
def on_gfx950() -> bool:
    """Report whether the "cuda" device's ``gcnArchName`` contains gfx950.

    The result is memoised, so the device is only queried on the first call.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx950" in arch_name


131
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int | tuple[int, int],
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Decide whether the ROCm custom paged-attention path may be used.

    The decision depends on the device architecture (gfx9 family vs.
    gfx11/gfx12), the query dtype, head/block sizes, GQA ratio, maximum
    sequence length, sliding-window setting, and the
    ``VLLM_ROCM_CUSTOM_PAGED_ATTN`` / AITER environment flags. Any other
    architecture yields ``False``.

    The result is memoised per argument combination, so all arguments must
    be hashable, and environment flags are effectively frozen at the first
    call for a given combination.

    Args:
        qtype: Query dtype; only ``torch.half`` and ``torch.bfloat16`` pass.
        head_size: Attention head size.
        block_size: KV-cache block size.
        gqa_ratio: Grouped-query ratio.
        max_seq_len: Maximum sequence length (must be <= 128 * 1024).
        sliding_window: ``0`` or ``(-1, -1)`` is treated as "disabled".
        kv_cache_dtype: Only consulted on gfx11/gfx12, where it must be
            ``"auto"``.
        alibi_slopes: Only consulted on gfx11/gfx12, where it must be
            ``None``.
        sinks: Must be ``None`` on every architecture.

    Returns:
        ``True`` when every requirement for the detected architecture holds.
    """
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])

    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    if ON_GFX9:
        return (
            (not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (gqa_ratio >= 1 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
            and sinks is None
        )

    else:
        return (
            ON_GFX11_GFX12
            and (
                not envs.VLLM_USE_V1
                or sliding_window == 0
                or sliding_window == (-1, -1)
            )
            and (qtype == torch.half or qtype == torch.bfloat16)
            and head_size == 128
            and block_size == 16
            and (gqa_ratio >= 3 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and alibi_slopes is None
            and kv_cache_dtype == "auto"
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )
180
181


182
183
class RocmPlatform(Platform):
    """Platform implementation for ROCm devices."""

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    # Quantization method names listed as supported on this platform.
    supported_quantization: list[str] = [
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
    ]
205

206
    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        """Pick the attention backend for vision-transformer layers.

        On gfx9 devices this is AITER flash attention when both
        ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_MHA`` are set, and
        plain flash attention otherwise. Other devices use torch SDPA.
        ``head_size`` and ``dtype`` are not consulted by this implementation.
        """
        # Imported lazily; at module level `_Backend` is only a placeholder
        # outside of type checking.
        from vllm.attention.backends.registry import _Backend

        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            return _Backend.ROCM_AITER_FA
        if on_gfx9():
            return _Backend.FLASH_ATTN
        return _Backend.TORCH_SDPA

216
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink,
        use_sparse,
    ) -> str:
        """Return the dotted import path of the attention backend class to use.

        For MLA, the backend defaults to AITER MLA when it is enabled or
        ``block_size == 1``, and to Triton MLA otherwise. For non-MLA on V1,
        the candidates are tried in order: FlexAttention, AITER flash
        attention, AITER unified attention, ROCm attention, and finally
        Triton attention as the default.

        ``head_size``, ``dtype``, ``kv_cache_dtype`` and ``has_sink`` are not
        consulted by this implementation.

        Raises:
            NotImplementedError: if ``use_sparse`` is truthy.
            RuntimeError: if MLA is requested with ``use_v1`` false, or if
                ``VLLM_USE_V1`` is disabled for a non-MLA request.
            ValueError: if the chosen MLA backend does not support
                ``block_size``, or the selected backend is not an MLA backend.
        """
        from vllm.attention.backends.registry import _Backend

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on ROCm.")
        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them."
                )

            from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
                is_aiter_mla_enabled,
            )

            if selected_backend is None:
                selected_backend = (
                    _Backend.ROCM_AITER_MLA
                    if is_aiter_mla_enabled() or block_size == 1
                    else _Backend.TRITON_MLA
                )

            if selected_backend == _Backend.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == _Backend.ROCM_AITER_MLA:
                if block_size == 1:
                    logger.info("Using AITER MLA backend on V1 engine.")
                    return (
                        "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                    )
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}. "
                    "(currently only supports block size 1)"
                )
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        # NOTE(review): the MLA branch above honours the `use_v1` argument,
        # while this branch reads `envs.VLLM_USE_V1` — confirm the two are
        # always in sync.
        if envs.VLLM_USE_V1:
            if selected_backend == _Backend.FLEX_ATTENTION:
                logger.info("Using FlexAttention backend on V1 engine.")
                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
            if (
                envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
            ) or selected_backend == _Backend.ROCM_AITER_FA:
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends."
                    "rocm_aiter_fa.AiterFlashAttentionBackend"
                )
            if (
                envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
            ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
                logger.info("Using Aiter Unified Attention backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends."
                    "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
                )
            if (
                envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
                or selected_backend == _Backend.ROCM_ATTN
            ):
                # rocm specific backend, with aiter and/or
                #   triton prefix-prefill
                logger.info("Using Rocm Attention backend on V1 engine.")
                return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
            # default case, using triton unified attention
            logger.info("Using Triton Attention backend on V1 engine.")
            return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend."
        )
310

311
312
313
314
315
316
317
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the active device via ``torch.cuda.set_device``."""
        torch.cuda.set_device(device)

318
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the (major, minor) capability torch reports for ``device_id``.

        Results are cached for up to eight distinct ``(cls, device_id)`` keys.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)
323

324
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Every unordered pair of the given physical devices is checked once.
        Returns False as soon as a pair is not a 1-hop XGMI link, or when
        the AMD SMI topology query raises ``AmdSmiException`` (which is
        logged).
        """
        # Query the handle list once rather than once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for idx, handle in enumerate(handles):
            # Only look at peers after `idx` so each pair is tested once.
            for peer_handle in handles[idx + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

344
    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for logical device ``device_id``.

        The logical id is translated to a physical id, the device's asic
        info is read through AMD SMI, and its ``device_id`` field is looked
        up in ``_ROCM_DEVICE_ID_NAME_MAP``. Unknown ids fall back to the
        reported ``market_name``.

        The cache is the outermost wrapper (below ``classmethod``) so that
        cache hits do not re-initialise and shut down AMD SMI.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        asic_device_id: str = asic_info["device_id"]
        known_name = _ROCM_DEVICE_ID_NAME_MAP.get(asic_device_id)
        if known_name is not None:
            return known_name
        return asic_info["market_name"]
355
356
357
358
359

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's device properties for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory
360
361

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply ROCm-specific defaults to ``vllm_config`` in place.

        - Sets ``cache_config.block_size`` to 16 when it is unset.
        - Resolves ``parallel_config.worker_cls == "auto"`` to the V1 GPU
          worker class path.
        - Appends ``"+rms_norm"`` to ``compilation_config.custom_ops`` when
          V1 and AITER RMSNorm are enabled, CUDA graphs are not disabled, and
          ``"-rms_norm"`` has not been listed.
        """
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the config's `cudagraph_mode` field with the enum member.
        # Comparing the config object itself to the enum is never equal, which
        # made this flag always False.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE

        use_v1 = envs.VLLM_USE_V1
        use_aiter_rms_norm = (
            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
        )

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_v1
            and use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")
388

389
390
391
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures and warn about partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially supported by ROCm: %s",
                model_arch,
                msg,
            )
403

404
405
406
407
408
409
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class quantization check, then enable Triton AWQ.

        A warning is logged when ``quant`` is ``"awq"`` and
        ``VLLM_USE_TRITON_AWQ`` is not already set.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # NOTE(review): the flag is set for every quantization method, not
        # only "awq" — confirm this is intended before narrowing it.
        envs.VLLM_USE_TRITON_AWQ = True
413
414
415
416

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the Punica wrapper class used on this platform."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
417
418

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return used device memory in bytes (total minus free).

        Peak memory statistics for ``device`` are reset first.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # Query once so free and total come from the same snapshot.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes
424
425
426

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
430

431
432
433
434
435
    @classmethod
    def supports_mx(cls) -> bool:
        """Report whether device 0's ``gcnArchName`` contains ``gfx95``."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in arch_name

436
437
438
    @classmethod
    def supports_fp8(cls) -> bool:
        """Report whether device 0's ``gcnArchName`` contains gfx94, gfx95 or gfx12."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
440
441
442
443

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Report whether device 0's ``gcnArchName`` contains ``gfx94``."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
445
446
447
448
449
450
451

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the FP8 dtype: ``float8_e4m3fnuz`` when ``is_fp8_fnuz()`` holds, else ``float8_e4m3fn``."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn
452

453
454
455
456
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Report whether custom allreduce should be enabled, based on device 0's architecture."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ["gfx94", "gfx95"]
        return any(gfx in gcn_arch for gfx in supported_archs)
459

460
461
462
463
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always returns True on this platform."""
        return True

464
465
    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Return ``multi_processor_count`` from torch's device properties for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).multi_processor_count
467
468
469

    @classmethod
    def is_navi(cls) -> bool:
        """Report whether device 0's ``gcnArchName`` contains ``gfx1``."""
        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
471
472

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted path of the static-graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
475

476
477
478
    @classmethod
    def device_count(cls) -> int:
        """Return the device count reported by ``cuda_device_count_stateless``."""
        return cuda_device_count_stateless()
479
480

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """Always returns True; neither argument is inspected."""
        return True
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype) -> None:
        """Raise if ``torch_dtype`` is bfloat16 and the device lacks capability 80.

        Every other dtype is accepted without further checks.

        Raises:
            ValueError: if bfloat16 is requested and
                ``cls.has_device_capability(80)`` is false. The message names
                the GPU and its capability.
        """
        # Only bfloat16 needs validation, and only below capability 80.
        if torch_dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        if capability is None:
            compute_str = "does not have a compute capability"
        else:
            version_str = capability.as_version_str()
            compute_str = f"has compute capability {version_str}"

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
506
507
508
509

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always returns True on this platform."""
        return True
510
511
512
513

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always returns True on this platform."""
        return True