rocm.py 25.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.logger import init_logger
12
from vllm.utils.torch_utils import cuda_device_count_stateless
13
from vllm.v1.attention.backends.registry import AttentionBackendEnum
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

17
if TYPE_CHECKING:
18
    from vllm.config import VllmConfig
19
    from vllm.v1.attention.selector import AttentionSelectorConfig
20

21
22
logger = init_logger(__name__)

23
try:
24
25
26
27
28
29
30
31
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
32
33
34
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

35
36
37
38
39
40
41
42
43
44
45
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

46
# Models not supported by ROCm.
47
_ROCM_UNSUPPORTED_MODELS: list[str] = []
48
49
50

# Models partially supported by ROCm.
# Architecture -> Reason.
51
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}
52
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
53
54
55
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
56
    "0x74a2": "AMD_Instinct_MI308X",
57
58
59
60
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
61
    "0x744c": "AMD_Radeon_RX7900XTX",
62
}
63

64
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        assert val == cuda_val
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`:
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorator that brackets each call to ``fn`` with amdsmi init/shutdown.

    ``amdsmi_init()`` runs before ``fn`` and ``amdsmi_shut_down()`` runs
    afterwards, including when ``fn`` raises.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Always release the amdsmi session, even on error.
            amdsmi_shut_down()
        return result

    return wrapper


90
91
92
93
94
95
96
97
98
99
100
101
102
103
@with_amdsmi_context
def _query_gcn_arch_from_amdsmi() -> str:
    """Return the gfx name of the first amdsmi processor.

    Raises:
        RuntimeError: if amdsmi lists no processors or the first one reports
            an empty ``target_graphics_version``.
    """
    processors = amdsmi_get_processor_handles()
    if not processors:
        raise RuntimeError("amdsmi did not return valid GCN arch")

    # target_graphics_version carries the gfx name,
    # e.g. 'gfx942' for MI300X/MI325X.
    first_asic = amdsmi_get_gpu_asic_info(processors[0])
    gfx_name = first_asic.get("target_graphics_version", "")
    if not gfx_name:
        raise RuntimeError("amdsmi did not return valid GCN arch")
    return gfx_name


104
def _get_gcn_arch() -> str:
    """
    Resolve the GCN arch string, preferring amdsmi over torch.cuda.

    amdsmi is tried first because it does not initialize CUDA. If that
    fails for any reason, the failure is logged and the arch is read from
    ``torch.cuda.get_device_properties("cuda").gcnArchName`` instead, which
    does initialize CUDA. Called once at module level; the result is stored
    in _GCN_ARCH.
    """
    try:
        arch = _query_gcn_arch_from_amdsmi()
    except Exception as err:
        logger.debug("Failed to get GCN arch via amdsmi: %s", err)
        logger.warning_once(
            "Failed to get GCN arch via amdsmi, falling back to torch.cuda. "
            "This will initialize CUDA and may cause "
            "issues if CUDA_VISIBLE_DEVICES is not set yet."
        )
    else:
        return arch

    # Last resort: ask torch.cuda (this initializes CUDA).
    return torch.cuda.get_device_properties("cuda").gcnArchName


122
123
124
125
126
127
128
129
130
131
132
133
# Resolved a single time at import. The amdsmi path avoids CUDA
# initialization, so Ray workers can still set CUDA_VISIBLE_DEVICES after
# importing this module.
# The flags below are plain Python bools (torch.compile/Dynamo safe).
_GCN_ARCH = _get_gcn_arch()

_ON_GFX942 = "gfx942" in _GCN_ARCH
_ON_GFX950 = "gfx950" in _GCN_ARCH
_ON_GFX1X = "gfx11" in _GCN_ARCH or "gfx12" in _GCN_ARCH
# MI3xx covers gfx942 and gfx950; gfx9 additionally covers gfx90a.
_ON_MI3XX = _ON_GFX942 or _ON_GFX950
_ON_GFX9 = "gfx90a" in _GCN_ARCH or _ON_MI3XX


134
def on_gfx1x() -> bool:
    """Return the import-time flag for a gfx11/gfx12 arch string."""
    return _ON_GFX1X
136
137


138
def on_mi3xx() -> bool:
    """Return the import-time flag for a gfx942/gfx950 arch string."""
    return _ON_MI3XX
140
141
142


def on_gfx9() -> bool:
    """Return the import-time flag for a gfx90a/gfx942/gfx950 arch string."""
    return _ON_GFX9
144
145


146
def on_gfx942() -> bool:
    """Return the import-time flag for a gfx942 arch string."""
    return _ON_GFX942
148
149


150
def on_gfx950() -> bool:
    """Return the import-time flag for a gfx950 arch string."""
    return _ON_GFX950
152
153


154
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Decide whether the ROCm custom paged-attention kernel may be used.

    Results are memoized per argument tuple via ``functools.cache``.

    On gfx9 the kernel is accepted for half/bfloat16 queries with head size
    64 or 128, block size 16 or 32 and a GQA ratio in [1, 16];
    ``kv_cache_dtype`` and ``alibi_slopes`` are not consulted on that path.
    On gfx11/gfx12 the constraints are tighter: head size 128, block size 16,
    GQA ratio in [3, 16], no ALiBi slopes and ``kv_cache_dtype == "auto"``.
    Both paths also require the sliding window to be disabled, a maximum
    sequence length of at most 128K, ``VLLM_ROCM_CUSTOM_PAGED_ATTN`` to be
    truthy and ``sinks`` to be None.
    """
    # Sliding window must be disabled (0 or (-1, -1)); the original note
    # cites an observed numerical discrepancy with it enabled.
    no_sliding_window = sliding_window == 0 or sliding_window == (-1, -1)
    half_precision = qtype == torch.half or qtype == torch.bfloat16
    within_max_len = max_seq_len <= 128 * 1024

    if _ON_GFX9:
        return (
            no_sliding_window
            and half_precision
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and within_max_len
            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
            and sinks is None
        )

    return (
        _ON_GFX1X
        and no_sliding_window
        and half_precision
        and head_size == 128
        and block_size == 16
        and 3 <= gqa_ratio <= 16
        and within_max_len
        and alibi_slopes is None
        and kv_cache_dtype == "auto"
        and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
        and sinks is None
    )
194
195


196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
@cache
def flash_attn_triton_available() -> bool:
    """Report whether the flash_attn Triton AMD backend can be used.

    Requires a gfx11/gfx12 arch, an importable ``flash_attn`` package that
    ships ``flash_attn.flash_attn_triton_amd``, and the environment variable
    ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` set to exactly ``"TRUE"``.
    The result is memoized.
    """
    if not on_gfx1x():
        return False
    try:
        from importlib.util import find_spec

        for module_name in ("flash_attn", "flash_attn.flash_attn_triton_amd"):
            if find_spec(module_name) is None:
                return False

        if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE":
            return True

        # Package is present but not opted in: tell the user how to enable.
        logger.info_once(
            "Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
            "Flash Attention Triton backend on RDNA."
        )
        return False
    except ImportError:
        return False


218
219
class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
220
    device_name: str = "rocm"
221
    device_type: str = "cuda"
222
    dispatch_key: str = "CUDA"
223
    ray_device_key: str = "GPU"
224
    dist_backend: str = "nccl"
225
226
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
227
228
229
230
231
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
    ]
232

233
    supported_quantization: list[str] = [
234
        "awq",
235
        "awq_marlin",  # will be overwritten with awq
236
        "gptq",
237
        "gptq_marlin",  # will be overwritten with gptq
238
239
240
241
242
243
244
245
246
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
247
        "bitsandbytes",
248
    ]
249

250
251
252
253
254
255
256
257
258
259
260
    @classmethod
    def import_kernels(cls) -> None:
        """Import the base kernels, then the ROCm-specific extension."""
        from contextlib import suppress

        super().import_kernels()

        # The ROCm extension is optional: a failed import is ignored.
        with suppress(ImportError):
            import vllm._rocm_C  # noqa: F401

261
    @classmethod
262
263
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Resolve the attention backend to use on ROCm and return its path.

        Resolution order:
          1. Sparse attention (``use_sparse``) -> ROCM_AITER_MLA_SPARSE.
          2. MLA (``use_mla``) -> one of the MLA backends.
          3. An explicitly selected non-MLA backend.
          4. Automatic selection from environment flags and the current
             vLLM config when ``selected_backend`` is None.

        Args:
            selected_backend: Backend requested by the caller, or None to
                let this method choose.
            attn_selector_config: Supplies ``block_size``, ``kv_cache_dtype``,
                ``use_sparse`` and ``use_mla``.
            num_heads: Not used by this implementation.

        Returns:
            The value of ``get_path()`` for the chosen backend.

        Raises:
            ValueError: sparse MLA with an fp8 KV cache, TRITON_MLA with
                block size 1, a non-MLA backend requested for MLA, or
                ROCM_AITER_FA requested on a non-gfx9 arch.
            AssertionError: sparse MLA with a block size other than 1.
            RuntimeError: the selected backend is not handled on ROCm.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            if selected_backend is None:
                # Default to AITER MLA when it is enabled or when the block
                # size is 1 (Triton MLA rejects block size 1 below).
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            else:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            from vllm.config import get_current_vllm_config_or_none

            vllm_config = get_current_vllm_config_or_none()
            if (
                vllm_config is not None
                and vllm_config.attention_config.use_prefill_decode_attention
            ):
                logger.info("Using Rocm Attention backend.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )
379

380
381
382
383
384
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return a new list of the backends accepted for ViT attention."""
        supported = [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
        ]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for ViT models.

        An explicit ``backend`` is returned as-is after asserting that it is
        one of ``get_supported_vit_attn_backends()``. Otherwise the first
        match wins:
          1. AITER enabled on gfx9 -> ROCM_AITER_FA.
          2. gfx9 with ``flash_attn`` installed and fp16/bf16 -> FLASH_ATTN.
          3. gfx11/gfx12 with the flash_attn Triton backend available and
             fp16/bf16 -> FLASH_ATTN.
          4. TORCH_SDPA.

        ``head_size`` is not consulted by this implementation.
        """
        if backend is not None:
            supported = cls.get_supported_vit_attn_backends()
            assert backend in supported, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        half_precision = dtype in (torch.float16, torch.bfloat16)

        if rocm_aiter_ops.is_enabled() and on_gfx9():
            logger.info_once("Using AITER Flash Attention backend for ViT model.")
            return AttentionBackendEnum.ROCM_AITER_FA

        if on_gfx9() and find_spec("flash_attn") is not None and half_precision:
            logger.info_once("Using Flash Attention backend for ViT model.")
            return AttentionBackendEnum.FLASH_ATTN

        # RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
        if on_gfx1x() and flash_attn_triton_available() and half_precision:
            logger.info_once(
                "Using Flash Attention (Triton backend) for ViT model on RDNA."
            )
            return AttentionBackendEnum.FLASH_ATTN

        logger.info_once("Using Torch SDPA backend for ViT model.")
        return AttentionBackendEnum.TORCH_SDPA

434
435
436
437
438
439
440
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current device via ``torch.cuda.set_device``."""
        torch.cuda.set_device(device)

441
    @classmethod
442
    @lru_cache(maxsize=8)
443
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the (major, minor) capability torch reports for ``device_id``."""
        cap_major, cap_minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=cap_major, minor=cap_minor)
446

447
    @classmethod
448
    @with_amdsmi_context
449
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Args:
            physical_device_ids: Indices into the amdsmi processor handle
                list.

        Returns:
            True when every unordered pair of the given devices reports a
            link with ``hops == 1`` and ``type == 2``; False otherwise, or
            when amdsmi raises while querying a pair.
        """
        from itertools import combinations

        # Enumerate processors once rather than once per requested id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]

        for handle, peer_handle in combinations(handles, 2):
            try:
                link_type = amdsmi_topo_get_link_type(handle, peer_handle)
            except AmdSmiException as error:
                logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                return False
            # type is 2 for XGMI
            if link_type["hops"] != 1 or link_type["type"] != 2:
                return False
        return True

467
    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for the device at logical index ``device_id``.

        The ASIC device id reported by amdsmi is looked up in
        ``_ROCM_DEVICE_ID_NAME_MAP``; unknown ids fall back to the
        ``market_name`` reported by amdsmi.

        The cache is the outermost wrapper (below ``classmethod``), so a
        cache hit returns without opening an amdsmi session.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        asic_device_id: str = asic_info["device_id"]
        if asic_device_id in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[asic_device_id]
        return asic_info["market_name"]
478
479
480
481
482

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's properties for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory
483
484

    @classmethod
485
    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
        """Append ROCm default entries to ``compilation_config.custom_ops``.

        Which ``+op`` entries are added depends on the AITER feature flags
        reported by ``rocm_aiter_ops`` and on whether CUDA graphs are in
        use. The list is mutated in place.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        custom_ops = compilation_config.custom_ops
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE

        use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()

        # Aiter rms norm performs best when CUDA Graph capture is enabled.
        if use_aiter_rms_norm and not is_eager_execution:
            if "-rms_norm" not in custom_ops:
                custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in custom_ops:
            custom_ops.append("+quant_fp8")

        # Shared-expert fusion needs grouped_topk, so a user-provided
        # '-grouped_topk' is dropped (with a warning).
        if use_aiter_fused_se and "-grouped_topk" in custom_ops:
            logger.warning_once(
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
                "requires the 'grouped_topk' custom op. Overriding the "
                "user-provided '-grouped_topk'."
            )
            custom_ops.remove("-grouped_topk")

        # Enable grouped_topk with AITER fused MoE unless the user already
        # enabled or disabled it.
        if use_aiter_fused_moe and not (
            "+grouped_topk" in custom_ops or "-grouped_topk" in custom_ops
        ):
            custom_ops.append("+grouped_topk")

        # Enable rotary embedding with AITER Triton rope unless the user
        # already enabled or disabled it.
        if use_aiter_triton_rope and not (
            "+rotary_embedding" in custom_ops or "-rotary_embedding" in custom_ops
        ):
            custom_ops.append("+rotary_embedding")

        # Default dispatch to rocm's sparse_attn_indexer implementation.
        # NOTE(review): appended unconditionally, unlike the guarded ops
        # above - confirm that ignoring a user-provided
        # '-sparse_attn_indexer' is intended.
        custom_ops.append("+sparse_attn_indexer")

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for ROCm.

        - Downgrades full CUDA graphs to PIECEWISE when decode or prefill
          context parallelism is enabled.
        - Fills in a missing KV cache block size (64 when AITER unified
          attention is enabled through the environment, otherwise 16).
        - Replaces an ``"auto"`` worker class with the V1 GPU worker.
        """
        from vllm.config.compilation import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            downgrade_warning = None
            if parallel_config.decode_context_parallel_size > 1:
                # decode context parallel does not support full cudagraphs
                downgrade_warning = (
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
            elif parallel_config.prefill_context_parallel_size > 1:
                # prefill context parallel does not support full cudagraphs
                downgrade_warning = (
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
            if downgrade_warning is not None:
                logger.warning_once(downgrade_warning)
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            # NOTE: selecting on get_env_variable_attn_backend() ==
            # AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN has been deprecated.
            # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
            # to see how we can transition to the new way of selecting
            # attention backends
            use_unified_attn = (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
            )
            if use_unified_attn:
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
578

579
580
581
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures and warn on partial support.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        try:
            reason = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
        except KeyError:
            # Not in the partially-supported table: nothing to report.
            return
        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            reason,
        )
593

594
595
596
597
598
599
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base quantization check, then force Triton AWQ on.

        A warning is logged when ``quant`` is ``"awq"`` and
        ``VLLM_USE_TRITON_AWQ`` is not already enabled.
        """
        super().verify_quantization(quant)
        if quant == "awq":
            if not envs.VLLM_USE_TRITON_AWQ:
                logger.warning(
                    "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                    " is not set, enabling VLLM_USE_TRITON_AWQ."
                )
        # Set for every quantization method, not only for AWQ.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
603
604
605
606

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the Punica wrapper class used on ROCm."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
607
608

    @classmethod
609
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return total minus free memory from ``torch.cuda.mem_get_info``.

        Peak memory statistics for ``device`` are reset as a side effect.
        """
        torch.cuda.reset_peak_memory_stats(device)
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes
615
616
617

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        communicator_path = (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
        return communicator_path
621

622
623
    @classmethod
    def supports_mx(cls) -> bool:
        """Return True when the detected arch string contains ``gfx95``."""
        return "gfx95" in _GCN_ARCH
625

626
627
    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True when the arch string contains gfx94, gfx95 or gfx12."""
        for gfx in ("gfx94", "gfx95", "gfx12"):
            if gfx in _GCN_ARCH:
                return True
        return False
629
630
631
632

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True when the detected arch string contains ``gfx94``."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        uses_fnuz = "gfx94" in _GCN_ARCH
        return uses_fnuz
634
635
636
637
638
639
640

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return ``float8_e4m3fnuz`` on fnuz platforms, else ``float8_e4m3fn``."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn
641

642
643
644
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True when the arch string contains gfx94 or gfx95."""
        # We only enable custom allreduce for MI300 series
        return "gfx94" in _GCN_ARCH or "gfx95" in _GCN_ARCH
646

647
648
649
650
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always True for the ROCm platform."""
        return True

651
652
    @classmethod
    def is_navi(cls) -> bool:
        """Return True when the detected arch string contains ``gfx1``."""
        navi_detected = "gfx1" in _GCN_ARCH
        return navi_detected
654
655

    @classmethod
656
657
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted path of the static graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
658

659
660
661
    @classmethod
    def device_count(cls) -> int:
        """Return the count from ``cuda_device_count_stateless()``."""
        return cuda_device_count_stateless()
662

663
    @classmethod
664
665
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise if bfloat16 is requested on a device below capability 8.0.

        Other dtypes are accepted without any check.

        Raises:
            ValueError: ``dtype`` is ``torch.bfloat16`` and
                ``cls.has_device_capability(80)`` is false.
        """
        if dtype != torch.bfloat16:
            return
        if cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        if capability is not None:
            version_str = capability.as_version_str()
            compute_str = f"has compute capability {version_str}"
        else:
            compute_str = "does not have a compute capability"

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
683
684
685
686

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always True for the ROCm platform."""
        return True
687
688
689
690

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always True for the ROCm platform."""
        return True
691
692
693
694

    @classmethod
    def num_compute_units(cls, device_id=0):
        """Return ``multi_processor_count`` from torch's device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.multi_processor_count