rocm.py 25.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.logger import init_logger
12
from vllm.utils.torch_utils import cuda_device_count_stateless
13
from vllm.v1.attention.backends.registry import AttentionBackendEnum
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

17
if TYPE_CHECKING:
18
    from vllm.config import VllmConfig
19
    from vllm.v1.attention.selector import AttentionSelectorConfig
20

21
22
logger = init_logger(__name__)

23
try:
24
25
26
27
28
29
30
31
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
32
33
34
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

35
36
37
38
39
40
41
42
43
44
45
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

46
# Model architectures that ROCm cannot run at all; checked by
# RocmPlatform.verify_model_arch, which raises for any listed name.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Model architectures that ROCm runs with caveats; verify_model_arch only
# logs a warning for these.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# amdsmi ASIC ``device_id`` (hex string) -> canonical device name returned by
# RocmPlatform.get_device_name. Ids not listed here fall back to amdsmi's
# ``market_name``.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    # MI300 family
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    # MI325 family
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    # MI300X HF variants
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    # Radeon
    "0x744c": "AMD_Radeon_RX7900XTX",
}
63

64
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`.
# If HIP_VISIBLE_DEVICES is set, CUDA_VISIBLE_DEVICES must be unset or empty
# (it is then copied from HIP_VISIBLE_DEVICES) or hold the exact same value;
# any other combination is rejected at import time.
if "HIP_VISIBLE_DEVICES" in os.environ:
    _hip_visible_devices = os.environ["HIP_VISIBLE_DEVICES"]
    _cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if not _cuda_visible_devices:
        os.environ["CUDA_VISIBLE_DEVICES"] = _hip_visible_devices
    elif _cuda_visible_devices != _hip_visible_devices:
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, and the bare assert did not say which values clashed.
        raise ValueError(
            f"HIP_VISIBLE_DEVICES={_hip_visible_devices!r} conflicts with "
            f"CUDA_VISIBLE_DEVICES={_cuda_visible_devices!r}; unset one of "
            "them or set both to the same value."
        )

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so each call runs inside its own amdsmi session.

    ``amdsmi_init()`` is called before ``fn`` and ``amdsmi_shut_down()`` is
    always called afterwards, even when ``fn`` raises. If ``amdsmi_init()``
    itself raises, ``fn`` is not invoked and no shutdown is attempted.
    """

    @wraps(fn)
    def _amdsmi_scoped(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _amdsmi_scoped


90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
@with_amdsmi_context
def _query_gcn_arch_from_amdsmi() -> str:
    """Return the gfx target name of the first processor reported by amdsmi.

    The value comes from the ``target_graphics_version`` field of the ASIC
    info, e.g. ``'gfx942'`` for MI300X/MI325X.

    Raises:
        RuntimeError: if amdsmi reports no processors, or the first one has a
            missing/empty ``target_graphics_version``.
    """
    processor_handles = amdsmi_get_processor_handles()
    if not processor_handles:
        raise RuntimeError("amdsmi did not return valid GCN arch")

    first_asic_info = amdsmi_get_gpu_asic_info(processor_handles[0])
    gfx_target = first_asic_info.get("target_graphics_version", "")
    if not gfx_target:
        raise RuntimeError("amdsmi did not return valid GCN arch")
    return gfx_target


@cache
def _get_gcn_arch_via_amdsmi() -> str:
    """
    Return the GCN architecture name, preferring amdsmi over torch.cuda.

    Querying amdsmi avoids initializing CUDA, which matters for Ray workers
    that must set CUDA_VISIBLE_DEVICES after importing vLLM. Only if the
    amdsmi query raises does this fall back to
    ``torch.cuda.get_device_properties``, which does initialize CUDA. The
    result is cached for the lifetime of the process.
    """
    try:
        gcn_arch = _query_gcn_arch_from_amdsmi()
    except Exception as e:
        logger.debug("Failed to get GCN arch via amdsmi: %s", e)
        logger.warning_once(
            "Failed to get GCN arch via amdsmi, falling back to torch.cuda. "
            "This will initialize CUDA and may cause "
            "issues if CUDA_VISIBLE_DEVICES is not set yet."
        )
    else:
        return gcn_arch

    # Fallback path: torch.cuda initializes CUDA as a side effect.
    return torch.cuda.get_device_properties("cuda").gcnArchName


124
125
@cache
def on_gfx1x() -> bool:
    """Return True on RDNA3/RDNA4 GPUs (arch name contains gfx11 or gfx12)."""
    gcn_arch = _get_gcn_arch_via_amdsmi()
    return "gfx11" in gcn_arch or "gfx12" in gcn_arch


130
@cache
def on_mi3xx() -> bool:
    """Return True when the arch name contains gfx942 (e.g. MI300X/MI325X)
    or gfx950."""
    gcn_arch = _get_gcn_arch_via_amdsmi()
    return "gfx942" in gcn_arch or "gfx950" in gcn_arch


@cache
def on_gfx9() -> bool:
    """Return True for the gfx9 data-center targets gfx90a, gfx942, gfx950."""
    gcn_arch = _get_gcn_arch_via_amdsmi()
    gfx9_targets = ("gfx90a", "gfx942", "gfx950")
    return any(target in gcn_arch for target in gfx9_targets)
140
141


142
143
@cache
def on_gfx942() -> bool:
    """Return True when the GPU arch name contains ``gfx942``."""
    return "gfx942" in _get_gcn_arch_via_amdsmi()


148
149
@cache
def on_gfx950() -> bool:
    """Return True when the GPU arch name contains ``gfx950``."""
    return "gfx950" in _get_gcn_arch_via_amdsmi()


154
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Return whether the ROCm custom paged-attention kernel can be used.

    Requirements shared by every supported architecture:
      * sliding window disabled (``0`` or the legacy ``(-1, -1)`` form) --
        a numerical discrepancy was observed with it enabled;
      * fp16 or bf16 queries;
      * ``max_seq_len <= 128 * 1024``;
      * no attention sinks;
      * ``VLLM_ROCM_CUSTOM_PAGED_ATTN`` enabled.

    On gfx9 (gfx90a/gfx942/gfx950) additionally: head size 64 or 128, block
    size 16 or 32, and ``1 <= gqa_ratio <= 16``.

    On gfx11/gfx12 additionally: head size 128, block size 16,
    ``3 <= gqa_ratio <= 16``, no ALiBi slopes, and ``kv_cache_dtype == "auto"``.

    Every other architecture returns False.
    """
    # Reuse the cached arch predicates instead of re-deriving the arch lists,
    # so this stays consistent with on_gfx9()/on_gfx1x().
    if on_gfx9():
        return bool(
            (sliding_window == 0 or sliding_window == (-1, -1))
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )

    if on_gfx1x():
        return bool(
            (sliding_window == 0 or sliding_window == (-1, -1))
            and qtype in (torch.half, torch.bfloat16)
            and head_size == 128
            and block_size == 16
            and 3 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024
            and alibi_slopes is None
            and kv_cache_dtype == "auto"
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )

    return False
198
199


200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
@cache
def flash_attn_triton_available() -> bool:
    """Report whether the flash_attn Triton AMD backend is usable.

    All of the following must hold, checked in this order:
      1. the GPU is gfx11/gfx12 (``on_gfx1x()``);
      2. ``flash_attn`` and ``flash_attn.flash_attn_triton_amd`` are findable;
      3. ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` is exactly ``"TRUE"`` (otherwise
         a one-time hint is logged).

    An ImportError while probing the modules yields False. The result is
    cached, so later environment changes are not observed.
    """
    if not on_gfx1x():
        return False
    try:
        from importlib.util import find_spec

        # `and` short-circuits, so the submodule is only probed when the
        # top-level package exists.
        modules_present = (
            find_spec("flash_attn") is not None
            and find_spec("flash_attn.flash_attn_triton_amd") is not None
        )
        if not modules_present:
            return False
        if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE":
            return True
        logger.info_once(
            "Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
            "Flash Attention Triton backend on RDNA."
        )
        return False
    except ImportError:
        return False


222
223
class RocmPlatform(Platform):
    """Platform implementation for AMD GPUs running ROCm.

    ROCm exposes its GPUs through torch's ``cuda`` device type, so most
    device/memory queries go through ``torch.cuda``; topology and naming
    queries use amdsmi.
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
    ]

    supported_quantization: list[str] = [
        "awq",
        "awq_marlin",  # will be overwritten with awq
        "gptq",
        "gptq_marlin",  # will be overwritten with gptq
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
    ]
    # bitsandbytes not supported on gfx9 (warp size 64 limitation).
    # Evaluated once, when the class body executes at import time.
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]

    @classmethod
    def import_kernels(cls) -> None:
        """Import ROCm-specific kernels."""
        super().import_kernels()

        import contextlib

        # Import ROCm-specific extension; a missing extension is tolerated.
        with contextlib.suppress(ImportError):
            import vllm._rocm_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Resolve the attention backend and return its import path.

        ``selected_backend`` may be ``None``; a backend is then chosen from
        the selector config, the AITER environment flags, the current vLLM
        config and the GPU architecture.

        Raises:
            ValueError: if the requested backend is incompatible with the
                selector config (fp8 KV cache with sparse MLA, unsupported
                block size for Triton MLA, non-MLA backend requested for MLA,
                AITER FA requested on a non-gfx9 GPU).
            RuntimeError: if ``selected_backend`` is not supported on ROCm.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            if selected_backend is None:
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            else:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            from vllm.config import get_current_vllm_config_or_none

            vllm_config = get_current_vllm_config_or_none()
            if (
                vllm_config is not None
                and vllm_config.attention_config.use_prefill_decode_attention
            ):
                logger.info("Using Rocm Attention backend.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            # NOTE(review): if VLLM_ROCM_USE_AITER_MHA is a plain bool, a True
            # value already returned at Priority 2, so this branch only fires
            # for a non-bool (e.g. None) value — confirm against vllm.envs.
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends accepted for ViT attention."""
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for ViT (vision encoder) layers.

        An explicit ``backend`` must be one of
        ``get_supported_vit_attn_backends()``. Otherwise the order is: AITER
        FA (AITER enabled on gfx9), Flash Attention (gfx9 with ``flash_attn``
        installed and fp16/bf16), Flash Attention Triton backend (gfx11/gfx12
        with it available and fp16/bf16), then Torch SDPA.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_enabled() and on_gfx9():
            logger.info_once("Using AITER Flash Attention backend for ViT model.")
            return AttentionBackendEnum.ROCM_AITER_FA

        if (
            on_gfx9()
            and find_spec("flash_attn") is not None
            and (dtype == torch.float16 or dtype == torch.bfloat16)
        ):
            logger.info_once("Using Flash Attention backend for ViT model.")
            return AttentionBackendEnum.FLASH_ATTN

        # RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
        if (
            on_gfx1x()
            and flash_attn_triton_available()
            and (dtype == torch.float16 or dtype == torch.bfloat16)
        ):
            logger.info_once(
                "Using Flash Attention (Triton backend) for ViT model on RDNA."
            )
            return AttentionBackendEnum.FLASH_ATTN

        logger.info_once("Using Torch SDPA backend for ViT model.")
        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return torch's (major, minor) capability for ``device_id``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Args:
            physical_device_ids: physical device indices, used directly to
                index amdsmi's processor handle list.

        Returns:
            True if every distinct pair of the given devices is joined by a
            single-hop XGMI link. False as soon as one pair is not, or if
            amdsmi raises while reporting a link type.
        """
        # Fetch the handle list once instead of once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Visit each unordered pair exactly once.
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a readable name for device ``device_id``.

        ``device_id`` is translated with ``device_id_to_physical_device_id``
        before indexing amdsmi's processor handles. ASIC ids listed in
        ``_ROCM_DEVICE_ID_NAME_MAP`` map to their canonical name; any other
        device reports amdsmi's ``market_name``.
        """
        # lru_cache wraps the amdsmi context (not the other way round) so that
        # cache hits skip the amdsmi_init()/amdsmi_shut_down() round trip.
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for ROCm.

        - Downgrades full CUDA graphs to PIECEWISE when decode or prefill
          context parallelism is enabled.
        - Picks a KV-cache block size when none was given: 64 when AITER
          unified attention is enabled, 16 otherwise.
        - Selects the GPU worker when ``worker_cls`` is ``"auto"``.
        - Enables custom ops used by the enabled AITER kernels, plus the ROCm
          ``sparse_attn_indexer`` implementation.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the cudagraph *mode*: comparing the CompilationConfig object
        # itself to an enum member is never equal, which silently disabled the
        # eager-mode check below.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            if (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
                # NOTE: This block has been deprecated
                # or get_env_variable_attn_backend()
                # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
                # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
                # to see how we can transition to the new way of selecting
                # attention backends
            ):
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

        if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
            logger.warning_once(
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
                "requires the 'grouped_topk' custom op. Overriding the "
                "user-provided '-grouped_topk'."
            )
            compilation_config.custom_ops.remove("-grouped_topk")
        # Ensure grouped_topk is always enabled when using AITER if
        # its not disabled by user
        if (
            use_aiter_fused_moe
            and "+grouped_topk" not in compilation_config.custom_ops
            and "-grouped_topk" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+grouped_topk")

        # Default dispatch to rocm's sparse_attn_indexer implementation
        compilation_config.custom_ops.append("+sparse_attn_indexer")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn for partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially supported by ROCm: %s",
                model_arch,
                msg,
            )

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Validate ``quant`` and force the Triton AWQ path on ROCm.

        A warning is logged only for ``"awq"`` without VLLM_USE_TRITON_AWQ,
        but the environment variable is exported for every method.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # NOTE(review): deliberately outside the `if`; presumably so AWQ
        # kernels reached via "awq_marlin" (remapped to awq on ROCm) also take
        # the Triton path — confirm before narrowing this to quant == "awq".
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return used device memory (total minus free from mem_get_info).

        Also resets the peak memory statistics of ``device``.
        """
        torch.cuda.reset_peak_memory_stats(device)
        free_mem, total_mem = torch.cuda.mem_get_info(device)
        return total_mem - free_mem

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator (shared with CUDA)."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_mx(cls) -> bool:
        """Return True when device 0's arch name contains ``gfx95``."""
        return "gfx95" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True when device 0 is a gfx94/gfx95/gfx12 target."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True when device 0 is a gfx94 target (fnuz fp8 format)."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 dtype native to this GPU (fnuz on gfx94)."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True when device 0 is a gfx94/gfx95 target."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ["gfx94", "gfx95"]
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def is_navi(cls) -> bool:
        """Return True when device 0's arch name contains ``gfx1``."""
        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the CUDA-graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        """Return the visible device count without initializing CUDA state."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise if ``dtype`` is bfloat16 and capability is below 8.0.

        Raises:
            ValueError: for bfloat16 on a device without capability >= 8.0.
        """
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True