# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Optional

import torch

import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless

from .interface import DeviceCapability, Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.attention.selector import AttentionSelectorConfig
    from vllm.config import VllmConfig

logger = init_logger(__name__)

# AMDSMI is optional at import time. If it is missing, a warning is logged and
# the names below stay unbound, so the AMDSMI-based helpers in this module
# (e.g. anything decorated with `with_amdsmi_context`) fail with NameError
# when called.
try:
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

# Compiled vLLM extension; a failed import is logged rather than raised.
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# NOTE(review): _ROCM_SWA_REASON is not referenced in this file; it is kept
# only in case other modules import it.
_ROCM_SWA_REASON = ()
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# Maps the `device_id` string reported by AMDSMI (`asic_info["device_id"]`)
# to a canonical device name. `RocmPlatform.get_device_name` uses this table
# and falls back to the AMDSMI `market_name` for IDs that are not listed.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    "0x744c": "AMD_Radeon_RX7900XTX",
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
def _sync_hip_visible_devices() -> None:
    """Mirror ``HIP_VISIBLE_DEVICES`` into ``CUDA_VISIBLE_DEVICES``.

    This platform uses ``CUDA_VISIBLE_DEVICES`` as its device-control
    variable. When ``HIP_VISIBLE_DEVICES`` is set:

    - if ``CUDA_VISIBLE_DEVICES`` is unset or empty, it is set to the HIP value;
    - otherwise the two values must be identical.

    Raises:
        AssertionError: if both variables are set (non-empty) and differ.
    """
    hip_val = os.environ.get("HIP_VISIBLE_DEVICES")
    if hip_val is None:
        return
    cuda_val = os.environ.get("CUDA_VISIBLE_DEVICES")
    # An empty CUDA_VISIBLE_DEVICES is treated like "unset" and overwritten.
    if cuda_val:
        assert hip_val == cuda_val, (
            f"HIP_VISIBLE_DEVICES={hip_val!r} conflicts with "
            f"CUDA_VISIBLE_DEVICES={cuda_val!r}; set only one or make them match."
        )
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = hip_val


_sync_hip_visible_devices()

# AMDSMI utils
# Note: like NVML on CUDA, AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device IDs.
# The major benefit of using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate *fn* so that every call runs inside an AMDSMI session.

    ``amdsmi_init()`` is called before *fn* runs. ``amdsmi_shut_down()`` is
    always called afterwards, including when *fn* raises. The wrapper carries
    *fn*'s metadata (name, docstring, ``__wrapped__``) via ``functools.wraps``.
    """

    @wraps(fn)
    def _call_in_amdsmi_session(*fn_args, **fn_kwargs):
        amdsmi_init()
        try:
            outcome = fn(*fn_args, **fn_kwargs)
        finally:
            amdsmi_shut_down()
        return outcome

    return _call_in_amdsmi_session


@cache
def on_gfx1x() -> bool:
    """Return whether the current device's ``gcnArchName`` is gfx11* or gfx12*.

    The check is a substring match on the architecture name. The result is
    memoised for the lifetime of the process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx11" in arch_name or "gfx12" in arch_name


@cache
def on_mi3xx() -> bool:
    """Return whether the current device's ``gcnArchName`` contains gfx942 or gfx950.

    The check is a substring match on the architecture name. The result is
    memoised for the lifetime of the process.
    """
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in gpu_arch for arch in ("gfx942", "gfx950"))


@cache
def on_gfx9() -> bool:
    """Return whether the current device is gfx90a, gfx942 or gfx950.

    The match is a substring test on ``gcnArchName``. The result is memoised
    for the lifetime of the process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx90a", "gfx942", "gfx950"):
        if family in arch_name:
            return True
    return False


@cache
def on_gfx950() -> bool:
    """Return whether the current device's ``gcnArchName`` contains gfx950.

    The result is memoised for the lifetime of the process.
    """
    return "gfx950" in torch.cuda.get_device_properties("cuda").gcnArchName


115
@cache
116
def use_rocm_custom_paged_attention(
117
118
119
120
121
122
123
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
124
125
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
126
) -> bool:
127
128
    from vllm._aiter_ops import rocm_aiter_ops

129
130
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
131
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
132

133
134
    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
135
    if ON_GFX9:
136
        return (
137
            (sliding_window == 0 or sliding_window == (-1, -1))
138
139
140
141
142
143
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (gqa_ratio >= 1 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
144
            and not (rocm_aiter_ops.is_pa_attn_enabled())
145
146
            and sinks is None
        )
147
148

    else:
149
150
        return (
            ON_GFX11_GFX12
151
            and (sliding_window == 0 or sliding_window == (-1, -1))
152
153
154
155
156
157
158
159
160
161
            and (qtype == torch.half or qtype == torch.bfloat16)
            and head_size == 128
            and block_size == 16
            and (gqa_ratio >= 3 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and alibi_slopes is None
            and kv_cache_dtype == "auto"
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )
162
163


class RocmPlatform(Platform):
    """Platform implementation for AMD GPUs on ROCm.

    Devices are driven through PyTorch's ``torch.cuda`` API. For that reason
    the device type, dispatch key and device-control environment variable
    mirror the CUDA platform.
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
    ]
    # bitsandbytes not supported on gfx9 (warp size 64 limitation).
    # NOTE: this runs once, at class-definition (import) time, and queries the
    # GPU through on_gfx9().
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the class path of the attention backend to use on ROCm.

        Args:
            selected_backend: Explicitly requested backend, or ``None`` to pick
                one automatically from environment flags.
            attn_selector_config: Attention configuration (block size,
                KV-cache dtype, MLA and sparse flags).

        Returns:
            The value of ``AttentionBackendEnum.<backend>.get_path()``.

        Raises:
            ValueError: If the requested backend cannot serve this
                configuration.
            RuntimeError: If the requested backend is not supported on ROCm.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            if selected_backend is None:
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is only supported on gfx9 architectures."
            )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
                logger.info("Using Rocm Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags.
            # This defaults to AITER FA only if MHA is not explicitly disabled.
            # NOTE(review): if VLLM_ROCM_USE_AITER_MHA is a plain bool, a True
            # value is already handled by Priority 2 and False fails the check
            # below, so this branch only fires for non-bool values. Confirm
            # whether it is still needed.
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends accepted for vision (ViT) encoders."""
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Pick the attention backend for vision (ViT) encoders.

        An explicit *backend* must be one of
        ``get_supported_vit_attn_backends()`` and is returned unchanged.
        Otherwise the order of preference is:

        1. AITER FA, if AITER MHA is enabled;
        2. FLASH_ATTN, on gfx9 when the ``flash_attn`` package is importable;
        3. TORCH_SDPA.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_mha_enabled():
            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            return AttentionBackendEnum.ROCM_AITER_FA

        if on_gfx9() and find_spec("flash_attn") is not None:
            return AttentionBackendEnum.FLASH_ATTN

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return ``torch.cuda.get_device_capability`` as a DeviceCapability (cached)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Args:
            physical_device_ids: Physical (AMDSMI) device indices to check.

        Returns:
            True if every pair of the given devices is linked with one hop of
            link type 2 (XGMI). False otherwise, or if AMDSMI raises while
            querying a pair.
        """
        # Query the processor handle list once instead of once per requested id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Visit each unordered pair exactly once.
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

    @classmethod
    # lru_cache sits outside with_amdsmi_context so that cache hits return
    # immediately, without an amdsmi_init()/amdsmi_shut_down() round trip.
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for logical device *device_id*.

        The logical id is mapped to a physical id first. The AMDSMI device id
        is then looked up in ``_ROCM_DEVICE_ID_NAME_MAP``; if it is not listed,
        the AMDSMI market name is returned instead.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device's torch properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply ROCm-specific adjustments to *vllm_config* in place.

        - Switch full CUDA graphs to ``PIECEWISE`` when decode or prefill
          context parallelism is enabled.
        - Pick a default KV-cache block size when none is set: 64 for AITER
          unified attention, otherwise 16.
        - Resolve ``worker_cls == "auto"`` to the GPU worker.
        - Enable the ``rms_norm`` / ``quant_fp8`` custom ops when the matching
          AITER kernels are enabled, unless they were explicitly disabled.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the cudagraph *mode*. Comparing the CompilationConfig object
        # itself against an enum member never matches.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE

        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            if (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
                # NOTE: This block has been deprecated
                # or get_env_variable_attn_backend()
                # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
                # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
                # to see how we can transition to the new way of selecting
                # attention backends
            ):
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        # Aiter rms norm performs best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn for partially supported ones.

        Raises:
            ValueError: If *model_arch* is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially supported by ROCm: %s",
                model_arch,
                msg,
            )

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base check, then force Triton AWQ kernels on ROCm."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # NOTE(review): this assignment is outside the `if`, so
        # VLLM_USE_TRITON_AWQ is set for every quantization method, not only
        # AWQ. Confirm this is intended.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the class path of the LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the memory currently in use on *device* (total minus free).

        Uses a single ``torch.cuda.mem_get_info`` query so that the free and
        total figures come from the same snapshot. As a side effect, the
        device's peak memory statistics are reset.
        """
        torch.cuda.reset_peak_memory_stats(device)
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the class path of the device communicator (shared with CUDA)."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_mx(cls) -> bool:
        """Return whether device 0's ``gcnArchName`` contains "gfx95"."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in gcn_arch

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether device 0 is a gfx94*, gfx95* or gfx12* GPU."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return whether the FNUZ fp8 variant applies (gfx94* devices)."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return ``float8_e4m3fnuz`` on FNUZ devices, else ``float8_e4m3fn``."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return whether custom allreduce is enabled for device 0's arch."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ["gfx94", "gfx95"]
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def is_navi(cls) -> bool:
        """Return whether device 0's ``gcnArchName`` contains "gfx1"."""
        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the class path of the static (CUDA) graph wrapper."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        """Return the visible device count without initializing CUDA state."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Reject bfloat16 on devices below compute capability 8.0.

        Raises:
            ValueError: If *dtype* is ``torch.bfloat16`` and the device does
                not report capability >= 8.0.
        """
        if dtype == torch.bfloat16 and not cls.has_device_capability(80):
            capability = cls.get_device_capability()
            gpu_name = cls.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs "
                "with compute capability of at least 8.0. "
                f"Your {gpu_name} GPU {compute_str}. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half."
            )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True