rocm.py 19.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.attention.backends.registry import AttentionBackendEnum
12
from vllm.logger import init_logger
13
from vllm.utils.torch_utils import cuda_device_count_stateless
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

17
if TYPE_CHECKING:
18
    from vllm.config import VllmConfig
19

20
21
logger = init_logger(__name__)

22
try:
23
24
25
26
27
28
29
30
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
31
32
33
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

34
35
36
37
38
39
40
41
42
43
44
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

45
# Architectures that ROCm cannot run at all; `verify_model_arch` raises
# for any name listed here.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Architectures that run on ROCm with caveats.
# Mapping: architecture name -> reason that is logged as a warning.
_ROCM_SWA_REASON = ()
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# Device id string (the "device_id" field of the AMDSMI ASIC info)
# -> canonical device name returned by `get_device_name`.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    # MI300 family
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    # MI325 family
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    # MI300X HF variants
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    # Radeon
    "0x744c": "AMD_Radeon_RX7900XTX",
}
63

64
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Runs at import time: when HIP_VISIBLE_DEVICES is set, make sure
# CUDA_VISIBLE_DEVICES describes the same device selection.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    # An unset *or empty* CUDA_VISIBLE_DEVICES is falsy here, so it falls
    # through to the `else` branch and is overwritten.
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Both variables are set: they must agree. Note that this check is
        # an `assert`, so it is skipped when Python runs with -O.
        assert val == cuda_val
    else:
        # Mirror the HIP setting into the CUDA variable.
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so that it runs between AMDSMI init and shutdown.

    ``amdsmi_init()`` is called before every invocation of ``fn`` and
    ``amdsmi_shut_down()`` is called afterwards, including when ``fn``
    raises. The return value (or exception) of ``fn`` is passed through.
    """

    @wraps(fn)
    def _amdsmi_scoped(*call_args, **call_kwargs):
        amdsmi_init()
        try:
            result = fn(*call_args, **call_kwargs)
        finally:
            # Always release the library, even if `fn` raised.
            amdsmi_shut_down()
        return result

    return _amdsmi_scoped


90
91
92
93
94
95
@cache
def on_gfx1x() -> bool:
    """Return True if the device's gcnArchName contains "gfx11" or "gfx12".

    The architecture name is read from
    ``torch.cuda.get_device_properties("cuda")`` and matched by substring.
    The result is memoized for the lifetime of the process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx11", "gfx12"):
        if family in arch_name:
            return True
    return False


96
@cache
def on_mi3xx() -> bool:
    """Return True if the device's gcnArchName contains "gfx942" or "gfx950".

    The architecture name is read from
    ``torch.cuda.get_device_properties("cuda")`` and matched by substring.
    The result is memoized for the lifetime of the process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for target in ("gfx942", "gfx950"):
        if target in arch_name:
            return True
    return False


@cache
def on_gfx9() -> bool:
    """Return True if gcnArchName contains "gfx90a", "gfx942" or "gfx950".

    The architecture name is read from
    ``torch.cuda.get_device_properties("cuda")`` and matched by substring.
    The result is memoized for the lifetime of the process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for target in ("gfx90a", "gfx942", "gfx950"):
        if target in arch_name:
            return True
    return False
106
107


108
109
110
111
112
113
@cache
def on_gfx950() -> bool:
    """Return True if the device's gcnArchName contains "gfx950".

    The architecture name is read from
    ``torch.cuda.get_device_properties("cuda")`` and matched by substring.
    The result is memoized for the lifetime of the process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx950" in arch_name


114
@cache
115
def use_rocm_custom_paged_attention(
116
117
118
119
120
121
122
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
123
124
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
125
) -> bool:
126
127
    from vllm._aiter_ops import rocm_aiter_ops

128
129
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
130
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
131

132
133
    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
134
    if ON_GFX9:
135
        return (
136
            (sliding_window == 0 or sliding_window == (-1, -1))
137
138
139
140
141
142
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (gqa_ratio >= 1 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
143
            and not (rocm_aiter_ops.is_pa_attn_enabled())
144
145
            and sinks is None
        )
146
147

    else:
148
149
        return (
            ON_GFX11_GFX12
150
            and (sliding_window == 0 or sliding_window == (-1, -1))
151
152
153
154
155
156
157
158
159
160
            and (qtype == torch.half or qtype == torch.bfloat16)
            and head_size == 128
            and block_size == 16
            and (gqa_ratio >= 3 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and alibi_slopes is None
            and kv_cache_dtype == "auto"
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )
161
162


163
164
class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
165
    device_name: str = "rocm"
166
    device_type: str = "cuda"
167
    dispatch_key: str = "CUDA"
168
    ray_device_key: str = "GPU"
169
    dist_backend: str = "nccl"
170
171
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
172

173
    supported_quantization: list[str] = [
174
175
176
177
178
179
180
181
182
183
184
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
185
    ]
186
187
188
    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]
189

190
    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> AttentionBackendEnum:
        """Pick the attention backend for vision-transformer layers.

        Selection order:
          1. ``ROCM_AITER_FA`` when AITER MHA is enabled.
          2. ``FLASH_ATTN`` on gfx9 when the ``flash_attn`` package can be
             found.
          3. ``TORCH_SDPA`` otherwise.

        ``head_size`` and ``dtype`` are accepted but do not influence the
        choice.
        """
        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if not rocm_aiter_ops.is_mha_enabled():
            # `find_spec` is only consulted on gfx9 hardware.
            flash_attn_usable = on_gfx9() and find_spec("flash_attn") is not None
            if flash_attn_usable:
                return AttentionBackendEnum.FLASH_ATTN
            return AttentionBackendEnum.TORCH_SDPA

        # Note: AITER FA is only supported for Qwen-VL models.
        # TODO: Add support for other VL models in their model class.
        return AttentionBackendEnum.ROCM_AITER_FA
207

208
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        use_mm_prefix,
        attn_type: str | None = None,
    ) -> str:
        """Resolve the attention backend for ROCm and return its ``get_path()``.

        Resolution order:
          1. ``use_sparse``: the sparse AITER MLA backend (fp8 KV cache is
             rejected, block size must be 1).
          2. ``use_mla``: one of the MLA backends. When no backend was
             requested, AITER MLA is chosen if it is enabled or the block
             size is 1, otherwise Triton MLA.
          3. An explicitly requested non-MLA backend (FlexAttention, Triton,
             ROCm attention, AITER FA, AITER unified attention).
          4. No backend requested: automatic selection driven by the
             ``VLLM_ROCM_USE_AITER*`` / ``VLLM_V1_USE_PREFILL_DECODE_ATTENTION``
             environment settings, defaulting to Triton attention.

        ``head_size``, ``dtype``, ``has_sink``, ``use_mm_prefix`` and
        ``attn_type`` are accepted but not consulted here.

        Raises:
            ValueError: unsupported fp8 KV cache for sparse MLA, a requested
                MLA backend that cannot handle the block size, a non-MLA
                backend requested for MLA, or AITER FA requested off gfx9.
            AssertionError: sparse MLA requested with a block size other
                than 1.
            RuntimeError: the requested backend is not supported on ROCm.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        if use_sparse:
            if kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if use_mla:
            if selected_backend is None:
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                # Message fixed: no leading space, and a space after the comma.
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            # Message fixed: no leading space, and a space after the comma.
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            else:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
                logger.info("Using Rocm Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            # NOTE(review): reaching this point with AITER enabled on gfx9
            # means VLLM_ROCM_USE_AITER_MHA was falsy in Priority 2, so this
            # branch only fires if that setting can be falsy without being
            # the literal `False` -- confirm against the envs definition.
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )
324

325
326
327
328
329
330
331
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Select ``device`` by forwarding it to ``torch.cuda.set_device``."""
        torch.cuda.set_device(device)
        return None

332
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the capability of ``device_id`` as a ``DeviceCapability``.

        The (major, minor) pair comes from
        ``torch.cuda.get_device_capability``. Up to 8 results are cached.
        """
        cap_major, cap_minor = torch.cuda.get_device_capability(device_id)
        capability = DeviceCapability(major=cap_major, minor=cap_minor)
        return capability
337

338
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Args:
            physical_device_ids: indices into the list returned by
                ``amdsmi_get_processor_handles()``.

        Returns:
            True when every unordered pair of the given devices reports a
            link with ``hops == 1`` and ``type == 2``. False as soon as one
            pair does not, or when the topology query raises
            ``AmdSmiException`` (which is logged).
        """
        # Query the handle list once instead of once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]

        for i, handle in enumerate(handles):
            # Slicing past `i` visits each unordered pair exactly once.
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
                # type is 2 for XGMI
                if link_type["hops"] != 1 or link_type["type"] != 2:
                    return False
        return True

358
    @classmethod
    # The cache wraps the AMDSMI context (not the other way round) so that
    # a cache hit returns immediately without an AMDSMI init/shutdown cycle.
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for ``device_id`` based on AMDSMI ASIC info.

        ``device_id`` is first translated to a physical device id, which
        indexes the AMDSMI processor handles. If the ASIC's ``device_id``
        string is a key of ``_ROCM_DEVICE_ID_NAME_MAP`` the mapped name is
        returned; otherwise the ASIC's ``market_name`` is returned.
        Up to 8 results are cached.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]
369
370
371
372
373

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch properties of ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory
374
375

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for ROCm.

        Steps performed:
          * Downgrade full CUDA graphs to ``PIECEWISE`` when decode or
            prefill context parallelism is enabled (size > 1).
          * Default ``cache_config.block_size`` to 16 when it is unset.
          * Replace a ``worker_cls`` of ``"auto"`` with the V1 GPU worker.
          * Append ``"+rms_norm"`` to the custom ops when AITER RMS norm is
            enabled, CUDA graphs are in use, and ``"-rms_norm"`` was not
            requested.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        # Eager execution means no CUDA graph capture at all. Compare the
        # *mode* with the enum member; comparing the whole config object
        # against it never identifies eager mode.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE

        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")
417

418
419
420
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Validate ``model_arch`` against the ROCm support tables.

        Raises:
            ValueError: if the architecture is listed in
                ``_ROCM_UNSUPPORTED_MODELS``.

        A warning carrying the recorded reason is logged when the
        architecture is a key of ``_ROCM_PARTIALLY_SUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        try:
            reason = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
        except KeyError:
            # Not in the partially-supported table: nothing to report.
            return

        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            reason,
        )
432

433
434
435
436
437
438
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class quantization check, then force Triton AWQ.

        After ``super().verify_quantization(quant)`` succeeds, a warning is
        logged if ``quant`` is ``"awq"`` while ``VLLM_USE_TRITON_AWQ`` is
        not enabled, and the ``VLLM_USE_TRITON_AWQ`` environment variable is
        set to ``"1"``.
        """
        super().verify_quantization(quant)

        awq_without_triton = quant == "awq" and not envs.VLLM_USE_TRITON_AWQ
        if awq_without_triton:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )

        # NOTE(review): the variable is set for every quantization method,
        # not only for "awq" -- presumably intentional; confirm.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
442
443
444
445

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the Punica wrapper class used on ROCm."""
        wrapper_path = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
        return wrapper_path
446
447

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the memory currently in use on ``device``.

        Peak-memory statistics for ``device`` are reset first. Usage is
        computed as total minus free, both taken from a single
        ``torch.cuda.mem_get_info`` query.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # `mem_get_info` yields (free, total); query it once rather than
        # once per operand.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes
453
454
455

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        communicator_path = (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
        return communicator_path
459

460
461
462
463
464
    @classmethod
    def supports_mx(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx95"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in arch_name

465
466
467
    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's gcnArchName contains gfx94, gfx95 or gfx12."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        for family in ("gfx94", "gfx95", "gfx12"):
            if family in arch_name:
                return True
        return False
469
470
471
472

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx94"."""
        # Only device 0 is inspected; this assumes MI300 platforms are
        # homogeneous.
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name
474
475
476
477
478
479
480

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 dtype to use: the fnuz variant when
        ``is_fp8_fnuz()`` is true, ``torch.float8_e4m3fn`` otherwise."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn
481

482
483
484
485
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx94" or "gfx95"."""
        # We only enable custom allreduce for MI300 series
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        for family in ("gfx94", "gfx95"):
            if family in arch_name:
                return True
        return False
488

489
490
491
492
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Report that this platform uses an opaque attention op (always True)."""
        uses_opaque_op = True
        return uses_opaque_op

493
494
    @classmethod
    def is_navi(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx1"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx1" in arch_name
496
497

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted path of the static-graph wrapper class."""
        wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return wrapper_path
500

501
502
503
    @classmethod
    def device_count(cls) -> int:
        """Return the device count from ``cuda_device_count_stateless()``."""
        num_devices = cuda_device_count_stateless()
        return num_devices
504

505
    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Reject ``torch.bfloat16`` on devices below capability 8.0.

        Every other dtype is accepted without any check.

        Raises:
            ValueError: if ``dtype`` is bfloat16 and
                ``cls.has_device_capability(80)`` is false. The message
                names the GPU and its compute capability (when known).
        """
        # Only bfloat16 carries a hardware requirement.
        if dtype != torch.bfloat16:
            return
        if cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        compute_str = (
            "does not have a compute capability"
            if capability is None
            else f"has compute capability {capability.as_version_str()}"
        )

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
525
526
527
528

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Report hybrid KV cache support on this platform (always True)."""
        supported = True
        return supported
529
530
531
532

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Report static graph mode support on this platform (always True)."""
        supported = True
        return supported