rocm.py 26.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
from datetime import timedelta
from functools import cache, lru_cache, wraps
from itertools import combinations
from typing import TYPE_CHECKING

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.attention.backends.registry import AttentionBackendEnum

from .interface import DeviceCapability, Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.v1.attention.selector import AttentionSelectorConfig


logger = init_logger(__name__)

# amdsmi is optional at import time. Without it, the amdsmi-based helpers
# below fail with NameError when called; _get_gcn_arch catches that and falls
# back to torch.cuda.
try:
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

# Compiled extensions: importing them triggers custom-op registration.
# A missing extension is logged rather than treated as fatal.
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Models not supported by ROCm.
50
_ROCM_UNSUPPORTED_MODELS: list[str] = []
51
52
53

# Models partially supported by ROCm.
# Architecture -> Reason.
54
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}
55
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
56
57
58
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
59
    "0x74a2": "AMD_Instinct_MI308X",
60
61
62
63
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
64
    "0x744c": "AMD_Radeon_RX7900XTX",
65
}
66

67
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`.
# vLLM controls device visibility on ROCm through CUDA_VISIBLE_DEVICES, so a
# HIP_VISIBLE_DEVICES value is mirrored into it when CUDA_VISIBLE_DEVICES is
# unset or empty, and a conflicting pair is rejected.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Explicit check instead of `assert`, which `python -O` would strip.
        if val != cuda_val:
            raise ValueError(
                f"HIP_VISIBLE_DEVICES={val!r} and CUDA_VISIBLE_DEVICES="
                f"{cuda_val!r} are both set but differ; set them to the same "
                "value or unset one of them."
            )
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI (like NVML on CUDA) is not affected by
# `{CUDA/HIP}_VISIBLE_DEVICES`: all the related functions work on real
# physical device ids. The major benefit of using AMDSMI is that it does not
# initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate *fn* so that every call runs inside an amdsmi session.

    ``amdsmi_init()`` is called before *fn*, and ``amdsmi_shut_down()`` is
    always called afterwards, even when *fn* raises. The wrapper carries
    *fn*'s metadata via ``functools.wraps``.
    """

    @wraps(fn)
    def _run_in_amdsmi_session(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _run_in_amdsmi_session


@with_amdsmi_context
def _query_gcn_arch_from_amdsmi() -> str:
    """Return the gfx target name of the first processor amdsmi reports.

    Reads ``target_graphics_version`` from the asic info of processor 0,
    e.g. ``'gfx942'`` for MI300X/MI325X.

    Raises:
        RuntimeError: if amdsmi reports no processors or an empty target.
            Errors raised by amdsmi itself propagate unchanged.
    """
    handles = amdsmi_get_processor_handles()
    target_gfx = (
        amdsmi_get_gpu_asic_info(handles[0]).get("target_graphics_version", "")
        if handles
        else ""
    )
    if not target_gfx:
        raise RuntimeError("amdsmi did not return valid GCN arch")
    return target_gfx

def _get_gcn_arch() -> str:
    """Return the GCN/gfx architecture string of the local AMD GPU.

    amdsmi is tried first because it does not initialize CUDA. If that fails
    for any reason, fall back to ``torch.cuda.get_device_properties``, which
    does initialize CUDA. Called once at module level; the result is stored
    in ``_GCN_ARCH``.
    """
    try:
        arch = _query_gcn_arch_from_amdsmi()
    except Exception as e:
        logger.debug("Failed to get GCN arch via amdsmi: %s", e)
        logger.warning_once(
            "Failed to get GCN arch via amdsmi, falling back to torch.cuda. "
            "This will initialize CUDA and may cause "
            "issues if CUDA_VISIBLE_DEVICES is not set yet."
        )
    else:
        return arch
    # Last resort: ask torch, at the cost of initializing CUDA.
    return torch.cuda.get_device_properties("cuda").gcnArchName


# Resolved once when this module is imported. amdsmi is preferred because it
# does not initialize CUDA, so Ray workers can still set CUDA_VISIBLE_DEVICES
# after import. The flags below are plain Python bools, which keeps them safe
# to read from torch.compile/Dynamo-traced code.
_GCN_ARCH = _get_gcn_arch()

_ON_GFX942 = "gfx942" in _GCN_ARCH
_ON_GFX950 = "gfx950" in _GCN_ARCH
# gfx942 or gfx950.
_ON_MI3XX = _ON_GFX942 or _ON_GFX950
# gfx90a, gfx942 or gfx950.
_ON_GFX9 = _ON_MI3XX or "gfx90a" in _GCN_ARCH
# RDNA3/RDNA4 (gfx11xx/gfx12xx).
_ON_GFX1X = "gfx11" in _GCN_ARCH or "gfx12" in _GCN_ARCH


def on_gfx1x() -> bool:
    """Return True on RDNA3/RDNA4 GPUs (gfx11xx / gfx12xx)."""
    return _ON_GFX1X


def on_mi3xx() -> bool:
    """Return True on gfx942 or gfx950 GPUs."""
    return _ON_MI3XX


def on_gfx9() -> bool:
    """Return True on gfx90a, gfx942 or gfx950 GPUs."""
    return _ON_GFX9


def on_gfx942() -> bool:
    """Return True on gfx942 GPUs."""
    return _ON_GFX942


def on_gfx950() -> bool:
    """Return True on gfx950 GPUs."""
    return _ON_GFX950


@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Decide whether the ROCm custom paged-attention kernel can be used.

    Both supported families require a disabled sliding window (``0`` or
    ``(-1, -1)``), fp16/bf16 queries, ``max_seq_len <= 128 * 1024``, no
    attention sinks, and ``envs.VLLM_ROCM_CUSTOM_PAGED_ATTN`` enabled.
    On top of that:

    * gfx9: head size 64 or 128, block size 16 or 32, GQA ratio in [1, 16].
    * gfx11/gfx12: head size 128, block size 16, GQA ratio in [3, 16],
      no ALiBi slopes, and ``kv_cache_dtype == "auto"``.

    Any other architecture yields False. Results are memoized per argument
    tuple.
    """
    # The sliding window must be disabled because of an observed numerical
    # discrepancy.
    if _ON_GFX9:
        return (
            (sliding_window == 0 or sliding_window == (-1, -1))
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )

    # Outside gfx9, only gfx11/gfx12 qualify, with tighter limits.
    return (
        _ON_GFX1X
        and (sliding_window == 0 or sliding_window == (-1, -1))
        and qtype in (torch.half, torch.bfloat16)
        and head_size == 128
        and block_size == 16
        and 3 <= gqa_ratio <= 16
        and max_seq_len <= 128 * 1024
        and alibi_slopes is None
        and kv_cache_dtype == "auto"
        and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
        and sinks is None
    )


@cache
def flash_attn_triton_available() -> bool:
    """Report whether flash-attn's Triton AMD backend is usable (RDNA only).

    Requires an RDNA3/RDNA4 GPU, an installed ``flash_attn`` package that
    provides ``flash_attn.flash_attn_triton_amd``, and the opt-in variable
    ``FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE``. An ImportError raised while
    probing counts as "not available". The result is cached.
    """
    if not on_gfx1x():
        return False
    try:
        from importlib.util import find_spec

        has_backend = (
            find_spec("flash_attn") is not None
            and find_spec("flash_attn.flash_attn_triton_amd") is not None
        )
        if not has_backend:
            return False
        if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE":
            return True
        logger.info_once(
            "Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
            "Flash Attention Triton backend on RDNA."
        )
        return False
    except ImportError:
        return False


class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    # ROCm reuses torch's CUDA device type, dispatch key and "nccl" backend.
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    # Ray's NOSET_*_VISIBLE_DEVICES switches for the HIP, CUDA and ROCR
    # visibility variables.
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
    ]

    # Quantization methods accepted on ROCm.
    supported_quantization: list[str] = [
        "awq",
        "awq_marlin",  # will be overwritten with awq
        "gptq",
        "gptq_marlin",  # will be overwritten with gptq
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
        "bitsandbytes",
    ]

    @classmethod
    def import_kernels(cls) -> None:
        """Import the base platform kernels plus the optional ``vllm._rocm_C``."""
        super().import_kernels()

        # Deliberately best-effort: a missing ROCm extension is not an error.
        try:
            import vllm._rocm_C  # noqa: F401
        except ImportError:
            pass

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use on ROCm.

        Args:
            selected_backend: Backend explicitly requested, or ``None`` to
                choose automatically from env vars and the current config.
            attn_selector_config: Supplies ``block_size``, ``kv_cache_dtype``,
                ``use_sparse`` and ``use_mla``.
            num_heads: Not used by the ROCm selection logic.

        Returns:
            ``AttentionBackendEnum.<X>.get_path()`` for the chosen backend.

        Raises:
            ValueError: fp8 KV cache with sparse MLA, Triton MLA with block
                size 1, a non-MLA backend requested for MLA, or AITER FA
                requested off gfx9.
            RuntimeError: if ``selected_backend`` is not handled on ROCm.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        # Sparse MLA has a single ROCm implementation.
        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            if selected_backend is None:
                # Block size 1 must go to AITER MLA: Triton MLA rejects it below.
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        # Explicitly requested (non-MLA) backends.
        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if not on_gfx9():
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )
            logger.info("Using Aiter Flash Attention backend.")
            return AttentionBackendEnum.ROCM_AITER_FA.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: AITER Unified Attention (must be checked before MHA).
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: AITER MHA (Flash Attention) on gfx9.
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: ROCM_ATTN (prefill-decode split) when the config asks.
            from vllm.config import get_current_vllm_config_or_none

            vllm_config = get_current_vllm_config_or_none()
            if (
                vllm_config is not None
                and vllm_config.attention_config.use_prefill_decode_attention
            ):
                logger.info("Using Rocm Attention backend.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: AITER FA unless MHA is explicitly disabled.
            # NOTE(review): when VLLM_ROCM_USE_AITER_MHA is a plain bool, every
            # case matched here was already returned by priority 2; this only
            # adds behavior for non-bool falsy values (e.g. None) — confirm.
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention.
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """List the attention backends a ViT encoder may use on ROCm."""
        supported = [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
        ]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for a ViT (vision encoder) model.

        An explicit ``backend`` is checked against
        ``get_supported_vit_attn_backends()`` and returned as-is. Otherwise
        the first match wins:

        1. AITER FA: AITER enabled and gfx9.
        2. flash-attn: gfx9, flash_attn installed, fp16/bf16.
        3. flash-attn Triton backend: RDNA, ``flash_attn_triton_available()``,
           fp16/bf16.
        4. Torch SDPA.

        ``head_size`` is not consulted.
        """
        if backend is not None:
            supported = cls.get_supported_vit_attn_backends()
            assert backend in supported, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {supported}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        half_precision = dtype == torch.float16 or dtype == torch.bfloat16

        if rocm_aiter_ops.is_enabled() and on_gfx9():
            logger.info_once("Using AITER Flash Attention backend for ViT model.")
            return AttentionBackendEnum.ROCM_AITER_FA

        if on_gfx9() and find_spec("flash_attn") is not None and half_precision:
            logger.info_once("Using Flash Attention backend for ViT model.")
            return AttentionBackendEnum.FLASH_ATTN

        # RDNA3/RDNA4 (gfx11xx/gfx12xx): flash-attn through its Triton backend.
        if on_gfx1x() and flash_attn_triton_available() and half_precision:
            logger.info_once(
                "Using Flash Attention (Triton backend) for ViT model on RDNA."
            )
            return AttentionBackendEnum.FLASH_ATTN

        logger.info_once("Using Torch SDPA backend for ViT model.")
        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current device via ``torch.cuda.set_device``."""
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return torch's (major, minor) capability for ``device_id`` (cached)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        capability = DeviceCapability(major=major, minor=minor)
        return capability

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop).

        Args:
            physical_device_ids: amdsmi processor indices, i.e. physical ids
                that are not remapped by ``{CUDA/HIP}_VISIBLE_DEVICES``.

        Returns:
            True if every pair of the given GPUs is linked by a single XGMI
            hop; False as soon as one pair is not, or if amdsmi fails to
            report a link for some pair.
        """
        # Query the processor handles once, not once per requested id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for handle, peer_handle in combinations(handles, 2):
            try:
                link_type = amdsmi_topo_get_link_type(handle, peer_handle)
            except AmdSmiException as error:
                logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                return False
            # type is 2 for XGMI
            if link_type["hops"] != 1 or link_type["type"] != 2:
                return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a human-readable name for logical device ``device_id``.

        The logical id is mapped to its physical id before indexing amdsmi's
        processor handles. Device ids listed in ``_ROCM_DEVICE_ID_NAME_MAP``
        resolve through it; any other device reports amdsmi's ``market_name``.
        """
        # lru_cache is applied outside with_amdsmi_context so that cache hits
        # skip the amdsmi_init()/amdsmi_shut_down() round-trip entirely.
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's properties for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory

    @classmethod
    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
        """Enable AITER-backed custom ops by default when AITER features are on.

        An op is enabled by appending ``"+<op>"`` to
        ``compilation_config.custom_ops``, but only when the user listed
        neither ``"+<op>"`` nor ``"-<op>"``. A user ``"-<op>"`` is respected
        as an opt-out, and repeated calls do not add duplicates.

        The one exception is ``grouped_topk``. AITER fused shared experts
        require it, so a user ``"-grouped_topk"`` is removed with a warning.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        # Mutated in place below; never rebound.
        custom_ops = compilation_config.custom_ops
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()

        def enable_by_default(op: str) -> None:
            # Append "+op" unless the user already chose "+op" or "-op".
            if f"+{op}" not in custom_ops and f"-{op}" not in custom_ops:
                custom_ops.append(f"+{op}")

        # Aiter rms norm performs best when CUDA Graph capture is enabled.
        if use_aiter_rms_norm and not is_eager_execution:
            enable_by_default("rms_norm")

        if use_aiter_fp8_linear:
            enable_by_default("quant_fp8")

        if use_aiter_fused_se and "-grouped_topk" in custom_ops:
            logger.warning_once(
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
                "requires the 'grouped_topk' custom op. Overriding the "
                "user-provided '-grouped_topk'."
            )
            # Drop every occurrence, not just the first one.
            while "-grouped_topk" in custom_ops:
                custom_ops.remove("-grouped_topk")

        # Ensure grouped_topk is enabled when using AITER fused MoE, unless
        # the user disabled it.
        if use_aiter_fused_moe:
            enable_by_default("grouped_topk")

        # Enable rotary embedding when using AITER Triton RoPE, unless the
        # user disabled it.
        if use_aiter_triton_rope:
            enable_by_default("rotary_embedding")

        # Default dispatch to ROCm's sparse_attn_indexer implementation. This
        # previously appended unconditionally, ignoring a user
        # "-sparse_attn_indexer" and duplicating the entry on repeated calls.
        enable_by_default("sparse_attn_indexer")

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for ROCm.

        * Full CUDA graphs fall back to PIECEWISE when decode or prefill
          context parallelism is enabled.
        * An unset KV-cache block size becomes 64 when AITER unified attention
          is enabled through env vars, otherwise 16.
        * ``worker_cls == "auto"`` resolves to the v1 GPU worker.
        """
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # Context parallelism (decode or prefill) cannot use full graphs.
            if parallel_config.decode_context_parallel_size > 1:
                override_msg = (
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
            elif parallel_config.prefill_context_parallel_size > 1:
                override_msg = (
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
            else:
                override_msg = None
            if override_msg is not None:
                logger.warning_once(override_msg)
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            # NOTE: selecting via get_env_variable_attn_backend() ==
            # AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN has been deprecated.
            # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
            # to see how we can transition to the new way of selecting
            # attention backends
            if envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER:
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn about partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch not in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            return
        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch],
        )

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Validate ``quant`` and export ``VLLM_USE_TRITON_AWQ=1`` on ROCm.

        The warning is emitted only for ``"awq"`` when the flag is not already
        enabled. The environment variable itself is exported for every
        ``quant`` value.
        """
        super().verify_quantization(quant)
        warn_about_awq = quant == "awq" and not envs.VLLM_USE_TRITON_AWQ
        if warn_about_awq:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # NOTE(review): set unconditionally, presumably because AWQ-family
        # methods such as "awq_marlin" are remapped to "awq" on ROCm (see the
        # supported_quantization comments); confirm before scoping to "awq".
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the GPU Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return device memory in use, computed as ``total - free``.

        Uses ``torch.cuda.mem_get_info``. As a side effect, torch's peak
        memory statistics for ``device`` are reset first.
        """
        torch.cuda.reset_peak_memory_stats(device)
        free_mem, total_mem = torch.cuda.mem_get_info(device)
        used_mem = total_mem - free_mem
        return used_mem

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the device communicator class path; ROCm uses the CUDA one."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_mx(cls) -> bool:
        """Return True on gfx95x GPUs."""
        return "gfx95" in _GCN_ARCH

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True on gfx94x, gfx95x and gfx12xx GPUs."""
        return "gfx94" in _GCN_ARCH or "gfx95" in _GCN_ARCH or "gfx12" in _GCN_ARCH

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if fp8 on this GPU uses the FNUZ variants (gfx94x)."""
        # _GCN_ARCH describes a single GPU, so this assumes MI300 platforms are
        # homogeneous.
        return "gfx94" in _GCN_ARCH

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 e4m3 dtype for this GPU (the FNUZ variant on gfx94x)."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Custom allreduce is enabled only for MI300 series (gfx94x/gfx95x)."""
        return "gfx94" in _GCN_ARCH or "gfx95" in _GCN_ARCH

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always True on ROCm."""
        return True

    @classmethod
    def is_navi(cls) -> bool:
        """Return True when the GCN arch string contains ``gfx1`` (Navi GPUs)."""
        return "gfx1" in _GCN_ARCH

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the static-graph wrapper class path; ROCm uses the CUDA one."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Construct a standalone NCCL-backed ``ProcessGroup``.

        ``backend`` is ignored: the group always uses torch's NCCL backend
        (asserted to be available), configured with ``timeout`` and
        registered for the ``cuda`` device.
        """
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        group = ProcessGroup(prefix_store, group_rank, group_size)

        options = ProcessGroupNCCL.Options()
        options._timeout = timeout
        nccl_backend = ProcessGroupNCCL(prefix_store, group_rank, group_size, options)
        nccl_type = ProcessGroup.BackendType.NCCL

        # Same call order as torch's own group construction path.
        group._set_default_backend(nccl_type)
        nccl_backend._set_sequence_number_for_group()
        group._register_backend(torch.device("cuda"), nccl_type, nccl_backend)
        return group

    @classmethod
    def device_count(cls) -> int:
        """Return the device count reported by ``cuda_device_count_stateless``."""
        return cuda_device_count_stateless()

697
    @classmethod
698
699
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        if dtype == torch.bfloat16:  # noqa: SIM102
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
715
716
                    "`dtype` flag in CLI, for example: --dtype=half."
                )
717
718
719
720

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always True on ROCm."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always True on ROCm."""
        return True

    @classmethod
    def num_compute_units(cls, device_id=0):
        """Return torch's ``multi_processor_count`` for ``device_id``."""
        properties = torch.cuda.get_device_properties(device_id)
        return properties.multi_processor_count