"requirements/common.txt" did not exist on "4a6769053ab2616f7f490e6ec5b8241e76ef0c2a"
rocm.py 30 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from datetime import timedelta
6
from functools import cache, lru_cache, wraps
7
from typing import TYPE_CHECKING
8

9
import regex as re
10
import torch
11
12
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
13

14
import vllm.envs as envs
15
from vllm.logger import init_logger
16
from vllm.utils.torch_utils import cuda_device_count_stateless
17
from vllm.v1.attention.backends.registry import AttentionBackendEnum
18

19
from .interface import DeviceCapability, Platform, PlatformEnum
20

21
if TYPE_CHECKING:
22
    from vllm.config import VllmConfig
23
    from vllm.v1.attention.selector import AttentionSelectorConfig
24

25
26
logger = init_logger(__name__)

# amdsmi is optional at import time. If this import fails, the amdsmi-based
# helpers below raise NameError when called; _get_gcn_arch catches that and
# falls back to torch.cuda.
try:
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

# Import the compiled extensions to trigger custom-op registration.
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# amdsmi ASIC-info "device_id" -> device name returned by
# RocmPlatform.get_device_name (which falls back to amdsmi's "market_name"
# for IDs not listed here).
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    "0x744c": "AMD_Radeon_RX7900XTX",
}


def _sync_hip_cuda_env_vars():
    """Keep HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES in agreement.

    An empty value counts as unset. When exactly one variable is set, its
    value is copied into the other. When both are set to different values,
    a ValueError is raised. When neither is set, nothing happens.
    """
    env = os.environ
    hip = env.get("HIP_VISIBLE_DEVICES") or None
    cuda = env.get("CUDA_VISIBLE_DEVICES") or None

    if hip is None and cuda is None:
        return
    if cuda is None:
        env["CUDA_VISIBLE_DEVICES"] = hip
        return
    if hip is None:
        env["HIP_VISIBLE_DEVICES"] = cuda
        return
    if hip == cuda:
        return
    raise ValueError(
        f"Inconsistent GPU visibility env vars: "
        f"HIP_VISIBLE_DEVICES='{hip}' vs "
        f"CUDA_VISIBLE_DEVICES='{cuda}'. "
        f"Please set only one, or ensure they match."
    )


# Sync at import time - catches misconfigurations from process start.
_sync_hip_cuda_env_vars()


# AMDSMI utils
# Note that AMDSMI (like NVML on CUDA) is not affected by
# `{CUDA/HIP}_VISIBLE_DEVICES`: all the related functions work on real
# physical device ids. The major benefit of using AMDSMI is that it will not
# initialize CUDA.


def with_amdsmi_context(fn):
    """Wrap ``fn`` so every call runs between amdsmi_init() and amdsmi_shut_down().

    amdsmi_shut_down() is called even when ``fn`` raises; the wrapped
    function's return value is passed through unchanged.
    """

    @wraps(fn)
    def _amdsmi_scoped(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _amdsmi_scoped


110
111
112
113
114
115
116
117
118
119
120
121
122
123
@with_amdsmi_context
def _query_gcn_arch_from_amdsmi() -> str:
    """Return the gfx target name of the first processor amdsmi reports.

    Raises:
        RuntimeError: if amdsmi lists no processors, or the first one has an
            empty/missing ``target_graphics_version``.
    """
    processors = amdsmi_get_processor_handles()
    if not processors:
        raise RuntimeError("amdsmi did not return valid GCN arch")
    # target_graphics_version holds the gfx name, e.g. 'gfx942' on MI300X/MI325X.
    asic = amdsmi_get_gpu_asic_info(processors[0])
    gfx_name = asic.get("target_graphics_version", "")
    if not gfx_name:
        raise RuntimeError("amdsmi did not return valid GCN arch")
    return gfx_name


def _get_gcn_arch() -> str:
    """
    Get GCN arch via amdsmi (no CUDA init), fallback to torch.cuda.
    Called once at module level; result stored in _GCN_ARCH.

    Any exception from the amdsmi path (including amdsmi not being importable)
    is logged and triggers the torch.cuda fallback, which initializes CUDA.
    """
    try:
        return _query_gcn_arch_from_amdsmi()
    except Exception as e:
        logger.debug("Failed to get GCN arch via amdsmi: %s", e)
        logger.warning_once(
            "Failed to get GCN arch via amdsmi, falling back to torch.cuda. "
            "This will initialize CUDA and may cause "
            "issues if CUDA_VISIBLE_DEVICES is not set yet."
        )
    # Ultimate fallback: use torch.cuda (will initialize CUDA)
    return torch.cuda.get_device_properties("cuda").gcnArchName


# Resolve once at module load. Prefers amdsmi (no CUDA init) so Ray workers
# can still set CUDA_VISIBLE_DEVICES after import; if amdsmi is unavailable
# the torch.cuda fallback in _get_gcn_arch does initialize CUDA.
# These are plain Python bools — fully torch.compile/Dynamo safe.
_GCN_ARCH = _get_gcn_arch()

_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
_ON_GFX942 = "gfx942" in _GCN_ARCH
_ON_GFX950 = "gfx950" in _GCN_ARCH


def _capability_from_gcn_arch(gcn_arch: str) -> tuple[int, int] | None:
    """
    Parse (major, minor) from a GCN arch string, mirroring how
    HIP derives hipDeviceProp_t.major / .minor.

    Format: gfx<MAJOR><MINOR><STEPPING>
      - 1-digit major  (gfx9xx):  "gfx" + M + m + stepping
      - 2-digit major  (gfx1xxx): "gfx" + MM + m + stepping

    Examples:
      gfx90a  -> (9, 0)    gfx942  -> (9, 4)    gfx950 -> (9, 5)
      gfx1100 -> (11, 0)   gfx1101 -> (11, 0)   gfx1200 -> (12, 0)

    Returns None only when the string is not gfx-prefixed at all
    (i.e. not a ROCm arch string). Raises on any string that looks
    like a GCN arch but does not match a known layout.
    """
    m = re.match(r"gfx(\d+)", gcn_arch)
    if not m:
        # Not a gfx string at all — caller should fall back to torch.cuda
        return None

    digits = m.group(1)
    n = len(digits)

    if n < 2:
        raise ValueError(
            f"GCN arch '{gcn_arch}' has too few digits ({n}) after 'gfx' "
            f"to derive a (major, minor) capability. "
            f"Please file a vLLM issue with your GPU model."
        )

    if n in (2, 3):
        # 1-digit major: gfx9 family
        # len 2: major + minor          (e.g. gfx90 from gfx90a)
        # len 3: major + minor + step   (e.g. gfx942)
        major = int(digits[0])
        minor = int(digits[1])
    elif n == 4:
        # 2-digit major: gfx10xx, gfx11xx, gfx12xx
        # major(2) + minor(1) + stepping(1)
        major = int(digits[:2])
        minor = int(digits[2])
    elif n >= 5:
        raise ValueError(
            f"GCN arch '{gcn_arch}' has {n} digits after 'gfx', which "
            f"exceeds the known 4-digit layout (MMms). Cannot determine "
            f"major/minor split unambiguously. "
            f"Please file a vLLM issue with your GPU model."
        )

    if major < 9:
        raise ValueError(
            f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
            f"major={major}, minor={minor}. "
            f"Major version < 9 is not expected for any supported AMD GPU. "
            f"Please file a vLLM issue with your GPU model."
        )

    if major > 12:
        raise ValueError(
            f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
            f"major={major}, minor={minor}. "
            f"Major version > 12 is beyond currently known AMD generations. "
            f"Please file a vLLM issue with your GPU model so support "
            f"can be added."
        )

    return (major, minor)


def on_gfx1x() -> bool:
    """Return True if the import-time GCN arch contains "gfx11" or "gfx12"."""
    return _ON_GFX1X


def on_mi3xx() -> bool:
    """Return True if the import-time GCN arch contains "gfx942" or "gfx950"."""
    return _ON_MI3XX


def on_gfx9() -> bool:
    """Return True if the import-time GCN arch is gfx90a, gfx942 or gfx950."""
    return _ON_GFX9


def on_gfx942() -> bool:
    """Return True if the import-time GCN arch contains "gfx942"."""
    return _ON_GFX942


def on_gfx950() -> bool:
    """Return True if the import-time GCN arch contains "gfx950"."""
    return _ON_GFX950


@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Decide whether ROCm's custom paged-attention kernel can be used.

    Both paths require sliding window disabled (``0`` or ``(-1, -1)``),
    fp16/bf16 queries, ``max_seq_len <= 128K``, no attention sinks and
    VLLM_ROCM_CUSTOM_PAGED_ATTN enabled.

    - gfx9: head_size 64/128, block_size 16/32, 1 <= gqa_ratio <= 16.
    - otherwise (only gfx11/gfx12 qualify): head_size 128, block_size 16,
      3 <= gqa_ratio <= 16, no ALiBi, and kv_cache_dtype "auto".

    NOTE(review): results are memoized by ``functools.cache``, which keeps
    every distinct argument tuple (including the tensor arguments) alive for
    the process lifetime.
    """
    # Sliding window must be disabled due to observed numerical discrepancy.
    if _ON_GFX9:
        return (
            (sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (gqa_ratio >= 1 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
            and sinks is None
        )

    else:
        return (
            _ON_GFX1X
            and (sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
            and head_size == 128
            and block_size == 16
            and (gqa_ratio >= 3 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and alibi_slopes is None
            and kv_cache_dtype == "auto"
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )


@cache
def flash_attn_triton_available() -> bool:
    """Report whether FlashAttention's Triton AMD backend can be used.

    True only on gfx11/gfx12 when both ``flash_attn`` and
    ``flash_attn.flash_attn_triton_amd`` can be found and the environment
    variable FLASH_ATTENTION_TRITON_AMD_ENABLE equals "TRUE". The answer is
    cached for the life of the process.
    """
    if not on_gfx1x():
        return False
    try:
        from importlib.util import find_spec

        # Check the parent package first; finding the submodule imports it.
        for module_name in ("flash_attn", "flash_attn.flash_attn_triton_amd"):
            if find_spec(module_name) is None:
                return False
        enabled = os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE"
        if not enabled:
            logger.info_once(
                "Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
                "Flash Attention Triton backend on RDNA."
            )
        return enabled
    except ImportError:
        return False


309
310
class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    # Ray "do not set visible devices" switches recognised on ROCm.
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
    ]

    supported_quantization: list[str] = [
        "awq",
        "awq_marlin",  # will be overwritten with awq
        "gptq",
        "gptq_marlin",  # will be overwritten with gptq
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
        "bitsandbytes",
    ]

    @classmethod
    def import_kernels(cls) -> None:
        """Import ROCm-specific kernels.

        Runs the base-class kernel imports, then best-effort imports
        ``vllm._rocm_C`` (an ImportError is silently ignored).
        """
        super().import_kernels()

        import contextlib

        # Import ROCm-specific extension
        with contextlib.suppress(ImportError):
            import vllm._rocm_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Resolve the import path of the attention backend to use on ROCm.

        Args:
            selected_backend: Backend explicitly requested by the user, or
                ``None`` to pick one automatically.
            attn_selector_config: Selector inputs; ``block_size``,
                ``kv_cache_dtype``, ``use_sparse`` and ``use_mla`` are read.
            num_heads: Not used by this implementation.

        Returns:
            The class path of the chosen backend.

        Raises:
            ValueError: For unsupported combinations (fp8 KV cache with sparse
                MLA, TRITON_MLA with block size 1, a non-MLA backend for an
                MLA model, or ROCM_AITER_FA on a non-gfx9 GPU).
            RuntimeError: If ``selected_backend`` is not supported on ROCm.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            if selected_backend is None:
                # AITER MLA when enabled, or when block size 1 rules out Triton MLA.
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            else:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            from vllm.config import get_current_vllm_config_or_none

            vllm_config = get_current_vllm_config_or_none()
            if (
                vllm_config is not None
                and vllm_config.attention_config.use_prefill_decode_attention
            ):
                logger.info("Using Rocm Attention backend.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """List the ViT attention backends get_vit_attn_backend will accept."""
        supported = [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
        ]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for ViT (vision transformer) layers.

        An explicitly requested ``backend`` is checked against
        ``get_supported_vit_attn_backends()`` and returned as-is. Otherwise
        the first matching rule wins:

        1. AITER enabled on gfx9 -> ROCM_AITER_FA
        2. gfx9 with ``flash_attn`` installed and fp16/bf16 -> FLASH_ATTN
        3. gfx11/gfx12 with the FlashAttention Triton backend available and
           fp16/bf16 -> FLASH_ATTN
        4. otherwise -> TORCH_SDPA

        ``head_size`` is not consulted by this implementation.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_enabled() and on_gfx9():
            logger.info_once("Using AITER Flash Attention backend for ViT model.")
            return AttentionBackendEnum.ROCM_AITER_FA

        if (
            on_gfx9()
            and find_spec("flash_attn") is not None
            and (dtype == torch.float16 or dtype == torch.bfloat16)
        ):
            logger.info_once("Using Flash Attention backend for ViT model.")
            return AttentionBackendEnum.FLASH_ATTN

        # RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
        if (
            on_gfx1x()
            and flash_attn_triton_available()
            and (dtype == torch.float16 or dtype == torch.bfloat16)
        ):
            logger.info_once(
                "Using Flash Attention (Triton backend) for ViT model on RDNA."
            )
            return AttentionBackendEnum.FLASH_ATTN

        logger.info_once("Using Torch SDPA backend for ViT model.")
        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the (major, minor) capability of the GPU.

        The capability is parsed from the GCN arch detected at import time
        (``_GCN_ARCH``), so ``device_id`` is ignored on that path; this
        assumes every GPU on the node shares that arch. ``device_id`` is
        only used by the torch.cuda fallback, which initializes CUDA.
        """
        cap = _capability_from_gcn_arch(_GCN_ARCH)
        if cap is not None:
            return DeviceCapability(major=cap[0], minor=cap[1])

        logger.warning_once(
            "Could not derive device capability from GCN arch '%s', "
            "falling back to torch.cuda (this will initialize CUDA).",
            _GCN_ARCH,
        )
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Args:
            physical_device_ids: indices into amdsmi's processor handle list
                (physical ids, not CUDA_VISIBLE_DEVICES-relative).

        Returns:
            True if every pair of the given GPUs is linked by a single XGMI
            hop. False if any pair is not, or if amdsmi raises
            AmdSmiException while querying a link (the error is logged).
        """
        # Query the handle list once rather than once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Each unordered pair is checked once.
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a human-readable name for logical device ``device_id``.

        The logical id is mapped to a physical id before querying amdsmi.
        Known ASIC device IDs map through _ROCM_DEVICE_ID_NAME_MAP; any other
        device returns amdsmi's ``market_name``. The cache sits outside the
        amdsmi context so repeated lookups skip amdsmi init/shutdown.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch.cuda's properties for ``device_id``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
        """Enable ROCm/AITER custom ops in ``compilation_config.custom_ops``.

        Each op is added as "+<op>" only when its AITER feature is enabled and
        the user has not disabled it with "-<op>". The one exception is
        grouped_topk: fused shared experts require it, so a user's
        "-grouped_topk" is removed (with a warning) in that case.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()

        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

        if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
            logger.warning_once(
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
                "requires the 'grouped_topk' custom op. Overriding the "
                "user-provided '-grouped_topk'."
            )
            compilation_config.custom_ops.remove("-grouped_topk")
        # Ensure grouped_topk is always enabled when using AITER if
        # its not disabled by user
        if (
            use_aiter_fused_moe
            and "+grouped_topk" not in compilation_config.custom_ops
            and "-grouped_topk" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+grouped_topk")
        # Enable rotary embedding when using AITER if its not disabled by user
        if (
            use_aiter_triton_rope
            and "+rotary_embedding" not in compilation_config.custom_ops
            and "-rotary_embedding" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rotary_embedding")

        # Default dispatch to rocm's sparse_attn_indexer implementation, unless
        # the user already chose (appending "+" next to a user "-" would leave
        # contradictory entries, and re-appending "+" would duplicate it).
        if (
            "+sparse_attn_indexer" not in compilation_config.custom_ops
            and "-sparse_attn_indexer" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+sparse_attn_indexer")

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for ROCm.

        - Downgrades full CUDA graphs to PIECEWISE when decode or prefill
          context parallelism is enabled (size > 1).
        - If the KV-cache block size is unset, uses 64 when AITER unified
          attention is enabled, otherwise 16.
        - Resolves an "auto" worker class to the v1 GPU worker.
        """
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            if (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
                # NOTE: This block has been deprecated
                # or get_env_variable_attn_backend()
                # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
                # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
                # to see how we can transition to the new way of selecting
                # attention backends
            ):
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported model architectures and warn on partial support.

        Raises:
            ValueError: If ``model_arch`` is listed in _ROCM_UNSUPPORTED_MODELS.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially supported by ROCm: %s",
                model_arch,
                msg,
            )

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Validate ``quant`` and force the Triton AWQ path on ROCm.

        NOTE(review): VLLM_USE_TRITON_AWQ is set to "1" for every
        quantization method, not only AWQ; the warning is emitted only for
        AWQ when the variable was not already enabled. Confirm the
        unconditional override is intended.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper (shared with GPU)."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return device memory in use as ``total - free`` from mem_get_info.

        Also resets torch's peak-memory statistics for ``device`` as a side
        effect.
        """
        torch.cuda.reset_peak_memory_stats(device)
        free_mem, total_mem = torch.cuda.mem_get_info(device)
        return total_mem - free_mem

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator (the CUDA one)."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_mx(cls) -> bool:
        """Return True when the import-time GCN arch contains "gfx95"."""
        return "gfx95" in _GCN_ARCH

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True on gfx94x, gfx95x and gfx12xx GPUs."""
        return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95", "gfx12"])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True when the GPU uses the FNUZ fp8 variant (gfx94x)."""
        # Only the single arch detected at import time (_GCN_ARCH) is
        # checked; this assumes MI300 platforms are homogeneous.
        return "gfx94" in _GCN_ARCH

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return float8_e4m3fnuz on FNUZ hardware, else float8_e4m3fn."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True if vLLM's custom allreduce should be used."""
        # We only enable custom allreduce for MI300 series (gfx94x / gfx95x).
        return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95"])

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always True on ROCm."""
        return True

    @classmethod
    def is_navi(cls) -> bool:
        """Return True when the import-time GCN arch contains "gfx1" (gfx1xxx)."""
        return "gfx1" in _GCN_ARCH

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static-graph (CUDA graph) wrapper."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a NCCL-backed ProcessGroup on ``prefix_store``.

        ``timeout`` is applied through the NCCL backend options. The
        ``backend`` argument is not used; NCCL is always selected.

        Raises:
            AssertionError: If torch was built without NCCL support.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the visible GPU count via cuda_device_count_stateless()."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise if ``dtype`` cannot be used on this GPU.

        Only bfloat16 is checked: it needs ``has_device_capability(80)``,
        i.e. compute capability of at least 8.0.

        Raises:
            ValueError: If bfloat16 is requested on a GPU below 8.0.
        """
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always True on ROCm."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always True on ROCm."""
        return True

    @classmethod
    def num_compute_units(cls, device_id=0):
        """Return ``multi_processor_count`` from torch.cuda for ``device_id``."""
        props = torch.cuda.get_device_properties(device_id)
        return props.multi_processor_count