rocm.py 34.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from datetime import timedelta
6
from functools import cache, lru_cache, wraps
7
from typing import TYPE_CHECKING
8

9
import regex as re
10
import torch
11
12
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
13

14
import vllm.envs as envs
15
from vllm.logger import init_logger
16
from vllm.v1.attention.backends.registry import AttentionBackendEnum
17

18
from .interface import DeviceCapability, Platform, PlatformEnum
19

20
if TYPE_CHECKING:
21
    from vllm.config import VllmConfig
22
    from vllm.config.kernel import IrOpPriorityConfig
23
    from vllm.v1.attention.selector import AttentionSelectorConfig
24

25
26
logger = init_logger(__name__)

27
try:
28
29
30
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
tmm77's avatar
tmm77 committed
31
        amdsmi_get_gpu_device_uuid,
32
33
34
35
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
36
        amdsmi_topo_get_numa_node_number,
37
    )
38
39
40
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

41
42
43
44
45
46
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
zhuwenwen's avatar
zhuwenwen committed
47
48
49
50
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)
51

52
# Model architectures that are not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Model architectures that are only partially supported by ROCm.
# Maps architecture name -> human-readable reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# Maps a device id string to a normalized device name.
# Several ids (e.g. the "VF" virtual-function variants) share one name.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    "0x744c": "AMD_Radeon_RX7900XTX",
    # RDNA 3.5 APUs (Strix Point / Strix Halo)
    "0x150e": "AMD_Radeon_890M",  # gfx1150, Strix Point
    "0x1586": "AMD_Radeon_8060S",  # gfx1151, Strix Halo
    # RDNA 4 discrete (Navi 48)
    "0x7550": "AMD_Radeon_RX9070XT",  # gfx1201
    "0x7551": "AMD_Radeon_R9700",  # gfx1201
}
75

76

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
@lru_cache(maxsize=8)
def _rocm_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
    """Get number of ROCm devices, caching based on the value of CUDA_VISIBLE_DEVICES
    at the time of call.

    This should be used instead of torch.accelerator.device_count() unless
    CUDA_VISIBLE_DEVICES has already been set to the desired value.

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released."""
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda

    if not torch.cuda._is_compiled():
        return 0
    # ROCm uses amdsmi instead of nvml for stateless device count
    # This requires a sufficiently modern version of Torch 2.4.0
    raw_count = (
        torch.cuda._device_count_amdsmi()
        if (hasattr(torch.cuda, "_device_count_amdsmi"))
        else -1
    )
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r


109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def _sync_hip_cuda_env_vars():
    """Mirror HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES onto each other.

    An empty string counts as "unset". If exactly one variable is set, its
    value is copied to the other. If neither is set, nothing happens.

    Raises:
        ValueError: both variables are set to different non-empty values.
    """
    hip_devices = os.environ.get("HIP_VISIBLE_DEVICES") or None
    cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES") or None

    if hip_devices is None and cuda_devices is None:
        return

    if cuda_devices is None:
        # Only HIP is set: propagate it to CUDA.
        os.environ["CUDA_VISIBLE_DEVICES"] = hip_devices
        return

    if hip_devices is None:
        # Only CUDA is set: propagate it to HIP.
        os.environ["HIP_VISIBLE_DEVICES"] = cuda_devices
        return

    # Both are set: they must agree.
    if hip_devices != cuda_devices:
        raise ValueError(
            f"Inconsistent GPU visibility env vars: "
            f"HIP_VISIBLE_DEVICES='{hip_devices}' vs "
            f"CUDA_VISIBLE_DEVICES='{cuda_devices}'. "
            f"Please set only one, or ensure they match."
        )


# Sync at import time - catches misconfigurations from process start.
_sync_hip_cuda_env_vars()
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorator that brackets ``fn`` with amdsmi init and shutdown.

    ``amdsmi_init()`` runs before the wrapped call and ``amdsmi_shut_down()``
    runs afterwards, including when ``fn`` raises. The wrapped function's
    return value (or exception) is passed through unchanged.
    """

    @wraps(fn)
    def _amdsmi_scoped(*call_args, **call_kwargs):
        amdsmi_init()
        try:
            result = fn(*call_args, **call_kwargs)
        finally:
            # Always release the amdsmi session, even on failure.
            amdsmi_shut_down()
        return result

    return _amdsmi_scoped


150
151
152
153
154
155
156
157
158
159
160
161
162
163
@with_amdsmi_context
def _query_gcn_arch_from_amdsmi() -> str:
    """Return the gfx arch name of the first amdsmi processor handle.

    Raises:
        RuntimeError: amdsmi reports no processor handles, or the ASIC info
            of the first handle has no (or an empty) graphics version.
    """
    failure = RuntimeError("amdsmi did not return valid GCN arch")

    processor_handles = amdsmi_get_processor_handles()
    if not processor_handles:
        raise failure

    # "target_graphics_version" holds the gfx name,
    # e.g. 'gfx942' for MI300X/MI325X.
    first_asic = amdsmi_get_gpu_asic_info(processor_handles[0])
    gfx_name = first_asic.get("target_graphics_version", "")
    if not gfx_name:
        raise failure
    return gfx_name


164
def _get_gcn_arch() -> str:
    """Resolve the GCN arch string, preferring amdsmi over torch.cuda.

    amdsmi is tried first because it does not initialize CUDA. Any failure
    there is logged and the arch is read from torch.cuda instead, which does
    initialize CUDA. Called once at module level; the result is stored in
    _GCN_ARCH.
    """
    try:
        gcn_arch = _query_gcn_arch_from_amdsmi()
    except Exception as e:
        logger.debug("Failed to get GCN arch via amdsmi: %s", e)
        logger.warning_once(
            "Failed to get GCN arch via amdsmi, falling back to torch.cuda. "
            "This will initialize CUDA and may cause "
            "issues if CUDA_VISIBLE_DEVICES is not set yet."
        )
    else:
        return gcn_arch

    # Ultimate fallback: use torch.cuda (will initialize CUDA)
    device_props = torch.cuda.get_device_properties("cuda")
    return device_props.gcnArchName


182
183
184
185
186
187
# Resolved a single time when the module is imported. The lookup goes through
# amdsmi (no CUDA init), so Ray workers can still set CUDA_VISIBLE_DEVICES
# after import.
# The flags below are plain Python bools — fully torch.compile/Dynamo safe.
_GCN_ARCH = _get_gcn_arch()

# Exact-architecture flags (substring match on the arch string).
_ON_GFX90A = "gfx90a" in _GCN_ARCH
_ON_GFX942 = "gfx942" in _GCN_ARCH
_ON_GFX950 = "gfx950" in _GCN_ARCH

# Family flags derived from the same substring checks.
_ON_GFX12X = "gfx12" in _GCN_ARCH
_ON_GFX1X = "gfx11" in _GCN_ARCH or _ON_GFX12X
_ON_MI3XX = _ON_GFX942 or _ON_GFX950
_ON_GFX9 = _ON_GFX90A or _ON_GFX942 or _ON_GFX950


196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def _capability_from_gcn_arch(gcn_arch: str) -> tuple[int, int] | None:
    """
    Parse (major, minor) from a GCN arch string, mirroring how
    HIP derives hipDeviceProp_t.major / .minor.

    Format: gfx<MAJOR><MINOR><STEPPING>
      - 1-digit major  (gfx9xx):  "gfx" + M + m + stepping
      - 2-digit major  (gfx1xxx): "gfx" + MM + m + stepping

    Examples:
      gfx90a  -> (9, 0)    gfx942  -> (9, 4)    gfx950 -> (9, 5)
      gfx1100 -> (11, 0)   gfx1101 -> (11, 0)   gfx1200 -> (12, 0)

    Returns None only when the string is not gfx-prefixed at all
    (i.e. not a ROCm arch string). Raises on any string that looks
    like a GCN arch but does not match a known layout.
    """
    m = re.match(r"gfx(\d+)", gcn_arch)
    if not m:
        # Not a gfx string at all — caller should fall back to torch.cuda
        return None

    digits = m.group(1)
    n = len(digits)

    if n < 2:
        raise ValueError(
            f"GCN arch '{gcn_arch}' has too few digits ({n}) after 'gfx' "
            f"to derive a (major, minor) capability. "
            f"Please file a vLLM issue with your GPU model."
        )

    if n in (2, 3):
        # 1-digit major: gfx9 family
        # len 2: major + minor          (e.g. gfx90 from gfx90a)
        # len 3: major + minor + step   (e.g. gfx942)
        major = int(digits[0])
        minor = int(digits[1])
    elif n == 4:
        # 2-digit major: gfx10xx, gfx11xx, gfx12xx
        # major(2) + minor(1) + stepping(1)
        major = int(digits[:2])
        minor = int(digits[2])
    elif n >= 5:
        raise ValueError(
            f"GCN arch '{gcn_arch}' has {n} digits after 'gfx', which "
            f"exceeds the known 4-digit layout (MMms). Cannot determine "
            f"major/minor split unambiguously. "
            f"Please file a vLLM issue with your GPU model."
        )

    if major < 9:
        raise ValueError(
            f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
            f"major={major}, minor={minor}. "
            f"Major version < 9 is not expected for any supported AMD GPU. "
            f"Please file a vLLM issue with your GPU model."
        )

    if major > 12:
        raise ValueError(
            f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
            f"major={major}, minor={minor}. "
            f"Major version > 12 is beyond currently known AMD generations. "
            f"Please file a vLLM issue with your GPU model so support "
            f"can be added."
        )

    return (major, minor)


267
def on_gfx1x() -> bool:
    """Return the module-level ``_ON_GFX1X`` flag (gfx11 or gfx12 arch)."""
    return _ON_GFX1X


def on_gfx12x() -> bool:
    """Return the module-level ``_ON_GFX12X`` flag (gfx12 arch)."""
    return _ON_GFX12X


def on_mi3xx() -> bool:
    """Return the module-level ``_ON_MI3XX`` flag (gfx942 or gfx950 arch)."""
    return _ON_MI3XX


def on_gfx9() -> bool:
    """Return the module-level ``_ON_GFX9`` flag (gfx90a/gfx942/gfx950)."""
    return _ON_GFX9


def on_gfx90a() -> bool:
    """Return the module-level ``_ON_GFX90A`` flag."""
    return _ON_GFX90A


def on_gfx942() -> bool:
    """Return the module-level ``_ON_GFX942`` flag."""
    return _ON_GFX942


def on_gfx950() -> bool:
    """Return the module-level ``_ON_GFX950`` flag."""
    return _ON_GFX950
293
294


295
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Decide whether the ROCm custom paged-attention kernel may be used.

    The answer depends on the detected architecture family:

      - gfx9 (gfx90a/gfx942/gfx950): head size 64 or 128, block size 16 or
        32, GQA ratio 1..16.
      - gfx11/gfx12: head size 128, block size 16, GQA ratio 3..16, no ALiBi
        slopes and ``kv_cache_dtype == "auto"``.
      - anything else: never.

    In both supported families the sliding window must be disabled, the query
    dtype must be half or bfloat16, ``max_seq_len`` must not exceed 128K and
    ``sinks`` must be ``None``. Results are memoized per argument tuple.
    """
    # Per-family limits; unsupported families bail out before any argument
    # is inspected.
    if _ON_GFX9:
        head_sizes, block_sizes, min_gqa_ratio = (64, 128), (16, 32), 1
    elif _ON_GFX1X:
        head_sizes, block_sizes, min_gqa_ratio = (128,), (16,), 3
    else:
        return False

    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    if not (sliding_window == 0 or sliding_window == (-1, -1)):
        return False
    if not (qtype == torch.half or qtype == torch.bfloat16):
        return False
    if head_size not in head_sizes:
        return False
    if block_size not in block_sizes:
        return False
    if not (gqa_ratio >= min_gqa_ratio and gqa_ratio <= 16):
        return False
    if not (max_seq_len <= 128 * 1024):
        return False
    # The two extra restrictions below only apply off gfx9.
    if not _ON_GFX9:
        if alibi_slopes is not None:
            return False
        if not (kv_cache_dtype == "auto"):
            return False
    return sinks is None
333
334


335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
@cache
def flash_attn_triton_available() -> bool:
    """Report whether the Flash Attention Triton AMD backend can be used.

    True only when all of the following hold: the device is gfx11/gfx12,
    both ``flash_attn`` and ``flash_attn.flash_attn_triton_amd`` can be
    located, and ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` is exactly ``"TRUE"``.
    An ImportError during module lookup counts as "not available". The
    result is memoized for the life of the process.
    """
    if not on_gfx1x():
        return False

    try:
        from importlib.util import find_spec

        required_modules = ("flash_attn", "flash_attn.flash_attn_triton_amd")
        if any(find_spec(module) is None for module in required_modules):
            return False
    except ImportError:
        return False

    if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") != "TRUE":
        # Modules are present but the opt-in switch is off: tell the user.
        logger.info_once(
            "Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
            "Flash Attention Triton backend on RDNA."
        )
        return False
    return True


357
358
359
360
def _get_backend_priorities(
    use_mla: bool,
    use_sparse: bool,
) -> list[AttentionBackendEnum]:
    """Return candidate attention backends, most preferred first.

    ``use_sparse`` takes precedence over ``use_mla``. For the regular
    (non-MLA, non-sparse) case the list always starts with ROCM_ATTN and
    ends with TRITON_ATTN and TURBOQUANT; the AITER backends are inserted in
    between only when the corresponding AITER checks pass.
    """
    from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops

    if use_sparse:
        return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE]

    if use_mla:
        if not rocm_aiter_ops.is_mla_enabled():
            return [AttentionBackendEnum.TRITON_MLA]
        return [
            AttentionBackendEnum.ROCM_AITER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.ROCM_AITER_TRITON_MLA,
        ]

    priorities = [AttentionBackendEnum.ROCM_ATTN]
    if rocm_aiter_ops.is_mha_enabled():
        priorities.append(AttentionBackendEnum.ROCM_AITER_FA)
    if is_aiter_found_and_supported():
        priorities.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
    priorities += [
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.TURBOQUANT,
    ]
    return priorities


391
392
class RocmPlatform(Platform):
    """Platform implementation for AMD ROCm devices."""

    _enum = PlatformEnum.ROCM

    # Naming and dispatch settings. ROCm is driven through torch's "cuda"
    # device type and "CUDA" dispatch key.
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"

    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
        "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
    ]

    # Quantization method names accepted on this platform.
    supported_quantization: list[str] = [
        "awq",
        "awq_marlin",  # will be overwritten with awq
        "gptq",
        "gptq_marlin",  # will be overwritten with gptq
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "mxfp4",
        "gpt_oss_mxfp4",
        "torchao",
        "bitsandbytes",
        "modelopt_fp4",
    ]
422

423
424
425
426
427
428
429
430
431
432
433
    @classmethod
    def import_kernels(cls) -> None:
        """Import the base platform kernels, then the ROCm extension.

        The ROCm-specific ``vllm._rocm_C`` extension is optional: an
        ImportError while loading it is deliberately ignored.
        """
        super().import_kernels()

        from contextlib import suppress

        # Best-effort import of the ROCm-specific extension.
        with suppress(ImportError):
            import vllm._rocm_C  # noqa: F401

434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> tuple[
        list[tuple["AttentionBackendEnum", int]],
        dict["AttentionBackendEnum", list[str]],
    ]:
        """Split the candidate backends into usable and rejected ones.

        Candidates come from ``_get_backend_priorities`` and each one is
        asked to validate the selector configuration. A backend whose class
        cannot be imported is rejected with the reason ``"ImportError"``.
        ``num_heads`` is accepted but not used here.

        Returns:
            A pair ``(valid, invalid)`` where ``valid`` is a list of
            ``(backend, priority)`` tuples (lower priority index is more
            preferred) and ``invalid`` maps each rejected backend to its list
            of reasons.
        """
        usable = []
        rejected = {}

        candidates = _get_backend_priorities(
            attn_selector_config.use_mla,
            attn_selector_config.use_sparse,
        )
        for rank, candidate in enumerate(candidates):
            try:
                problems = candidate.get_class().validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                problems = ["ImportError"]

            if not problems:
                usable.append((candidate, rank))
            else:
                rejected[candidate] = problems

        return usable, rejected

467
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use.

        If ``selected_backend`` is given it is validated on its own and
        either returned or rejected with an error; there is no fallback to
        another backend. Otherwise every candidate is validated and the valid
        one with the lowest priority index is chosen.

        Raises:
            ValueError: the explicitly selected backend is invalid for this
                configuration, or no candidate backend is valid.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # Explicit selection: validate only that backend.
        if selected_backend is not None:
            try:
                invalid_reasons = selected_backend.get_class().validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons = ["ImportError"]

            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info_once(
                "Using %s backend (selected via --attention-backend).",
                selected_backend.name,
            )
            return selected_backend.get_path()

        # No backend was requested: look for a valid one automatically.
        candidates, invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
        )
        per_backend_reasons = [
            f"{backend.name}: [{', '.join(reasons)}]"
            for backend, reasons in invalid_reasons.items()
        ]
        reasons_str = "{" + ", ".join(per_backend_reasons) + "}"
        config_str = repr(attn_selector_config)
        logger.debug_once(
            f"Some attention backends are not valid for {cls.device_name} with "
            f"{config_str}. Reasons: {reasons_str}."
        )
        if not candidates:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # Pick the entry with the smallest priority index; on ties the
        # earliest entry wins.
        chosen_backend = min(candidates, key=lambda entry: entry[1])[0]
        candidate_names = [f"'{entry[0].name}'" for entry in candidates]
        valid_str = "[" + ", ".join(candidate_names) + "]"

        if not invalid_reasons:
            logger.info_once(
                "Using %s backend out of potential backends: %s.",
                chosen_backend.name,
                valid_str,
            )
        else:
            rejected_str = ", ".join(b.name for b in invalid_reasons)
            logger.info(
                "Found incompatible backend(s) [%s] with %s. "
                "Overriding with %s out of potential backends: %s.",
                rejected_str,
                attn_selector_config.attn_type,
                chosen_backend.name,
                valid_str,
            )

        return chosen_backend.get_path()

555
556
557
558
559
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends accepted for ViT attention."""
        supported = [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
        ]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for ViT models.

        An explicitly passed ``backend`` is returned as-is after asserting
        that it is in the supported list. Otherwise the first match wins:

          1. AITER Flash Attention, when AITER is enabled on gfx9.
          2. Flash Attention, on gfx9 with ``flash_attn`` installed and a
             float16/bfloat16 dtype.
          3. Flash Attention (Triton backend), on gfx11/gfx12 when it is
             available and the dtype is float16/bfloat16.
          4. Torch SDPA.

        ``head_size`` is accepted but not used here.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_enabled() and on_gfx9():
            logger.info_once("Using AITER Flash Attention backend for ViT model.")
            return AttentionBackendEnum.ROCM_AITER_FA

        is_half_precision = dtype == torch.float16 or dtype == torch.bfloat16

        if on_gfx9() and find_spec("flash_attn") is not None and is_half_precision:
            logger.info_once("Using Flash Attention backend for ViT model.")
            return AttentionBackendEnum.FLASH_ATTN

        # RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
        if on_gfx1x() and flash_attn_triton_available() and is_half_precision:
            logger.info_once(
                "Using Flash Attention (Triton backend) for ViT model on RDNA."
            )
            return AttentionBackendEnum.FLASH_ATTN

        logger.info_once("Using Torch SDPA backend for ViT model.")
        return AttentionBackendEnum.TORCH_SDPA

609
610
611
612
613
614
615
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.

        Delegates to ``torch.cuda.set_device``.
        """
        torch.cuda.set_device(device)

616
617
618
619
    @classmethod
    def manual_seed_all(cls, seed: int) -> None:
        """Seed the random number generators via ``torch.cuda.manual_seed_all``."""
        torch.cuda.manual_seed_all(seed)

620
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the device capability as ``(major, minor)``.

        The capability is normally parsed from the module-level GCN arch
        string, in which case ``device_id`` is not consulted. Only when that
        string cannot be parsed is torch.cuda queried for ``device_id``
        (which initializes CUDA). Results are memoized.
        """
        parsed = _capability_from_gcn_arch(_GCN_ARCH)
        if parsed is None:
            logger.warning_once(
                "Could not derive device capability from GCN arch '%s', "
                "falling back to torch.cuda (this will initialize CUDA).",
                _GCN_ARCH,
            )
            parsed = torch.cuda.get_device_capability(device_id)

        major, minor = parsed
        return DeviceCapability(major=major, minor=minor)
634

635
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Every unordered pair of the given physical devices is checked.
        Returns False as soon as one pair is not a 1-hop XGMI link, or when
        the amdsmi topology query fails for a pair (the failure is logged).
        Returns True when all pairs pass, including when fewer than two
        device ids are given.
        """
        # Query the handle list once instead of once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]

        for index, handle in enumerate(handles):
            # Only look at peers after `index`, so each pair is visited once.
            for peer_handle in handles[index + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
                # type is 2 for XGMI
                if link_type["hops"] != 1 or link_type["type"] != 2:
                    return False
        return True

655
    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for the given logical device.

        The logical ``device_id`` is first mapped to a physical device id.
        Known device ids are translated through ``_ROCM_DEVICE_ID_NAME_MAP``;
        anything else falls back to the market name reported by amdsmi.

        The cache sits outside the amdsmi context on purpose: a cache hit
        returns immediately without an amdsmi init/shutdown round trip.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        asic_info_device_id: str = asic_info["device_id"]
        if asic_info_device_id in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[asic_info_device_id]
        return asic_info["market_name"]
666

tmm77's avatar
tmm77 committed
667
668
669
670
671
672
673
674
675
676
677
678
679
680
    @classmethod
    @with_amdsmi_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the amdsmi UUID of the device at index ``device_id``.

        ``device_id`` indexes directly into the amdsmi processor handle
        list. Returns an empty string (after logging the error) when either
        the handle lookup or the UUID query raises ``AmdSmiException``.
        """
        try:
            device = amdsmi_get_processor_handles()[device_id]
        except AmdSmiException as error:
            logger.error("amdsmi device query failed ", exc_info=error)
            return ""
        try:
            return amdsmi_get_gpu_device_uuid(device)
        except AmdSmiException as error:
            logger.error("amdsmi device uuid query failed ", exc_info=error)
            # Without this return the function used to fall through to an
            # unbound local and crash with UnboundLocalError.
            return ""

681
682
683
684
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's properties of ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory
685
686

    @classmethod
    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
        """Add ROCm default entries to ``compilation_config.custom_ops``.

        ``custom_ops`` holds ``"+name"`` / ``"-name"`` strings. Defaults are
        only added when the user has not explicitly disabled the op, with
        one exception: a user-provided ``-grouped_topk`` is removed (with a
        warning) when AITER shared-expert fusion is enabled, because that
        feature requires the op.

        The list is modified in place.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()

        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

        if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
            logger.warning_once(
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
                "requires the 'grouped_topk' custom op. Overriding the "
                "user-provided '-grouped_topk'."
            )
            compilation_config.custom_ops.remove("-grouped_topk")
        # Ensure grouped_topk is always enabled when using AITER if
        # its not disabled by user
        if (
            use_aiter_fused_moe
            and "+grouped_topk" not in compilation_config.custom_ops
            and "-grouped_topk" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+grouped_topk")

        # Default dispatch to rocm's sparse_attn_indexer implementation.
        # Same guard as grouped_topk above: do not add a duplicate entry on
        # repeated calls, and do not put "+op" next to a user's "-op".
        if (
            "+sparse_attn_indexer" not in compilation_config.custom_ops
            and "-sparse_attn_indexer" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+sparse_attn_indexer")

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for ROCm.

        Full CUDA graphs are downgraded to PIECEWISE (with a warning) when
        decode or prefill context parallelism is enabled, and a worker class
        of ``"auto"`` is resolved to the v1 GPU worker.
        """
        from vllm.config.compilation import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config

        wants_full_graphs = compilation_config.cudagraph_mode.has_full_cudagraphs()
        # decode context parallel does not support full cudagraphs
        if wants_full_graphs and parallel_config.decode_context_parallel_size > 1:
            logger.warning_once(
                "Decode context parallel (DCP) is enabled, which is "
                "incompatible with full CUDA graphs. "
                "Overriding cudagraph_mode to PIECEWISE."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
        # prefill context parallel do not support full cudagraphs
        elif wants_full_graphs and parallel_config.prefill_context_parallel_size > 1:
            logger.warning_once(
                "Prefill context parallel (PCP) is enabled, which is "
                "incompatible with full CUDA graphs. "
                "Overriding cudagraph_mode to PIECEWISE."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
754

755
756
757
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Check ``model_arch`` against the ROCm support tables.

        Logs a warning when the architecture is listed as partially
        supported.

        Raises:
            ValueError: the architecture is listed as unsupported.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch not in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            return

        reason = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            reason,
        )
769

770
771
772
773
774
775
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base quantization check, then force Triton AWQ on.

        A warning is logged when ``quant`` is ``"awq"`` and
        ``VLLM_USE_TRITON_AWQ`` is not already set.
        """
        super().verify_quantization(quant)

        awq_without_triton = quant == "awq" and not envs.VLLM_USE_TRITON_AWQ
        if awq_without_triton:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # NOTE(review): the env var is written for every quantization method,
        # not only for "awq" — confirm this is intended.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
779
780
781
782

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted import path of the Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
783
784

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the memory currently in use on ``device``.

        Computed as total minus free memory from ``torch.cuda.mem_get_info``.
        As a side effect the peak memory statistics of the device are reset
        first.
        """
        torch.cuda.reset_peak_memory_stats(device)
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        used_bytes = total_bytes - free_bytes
        return used_bytes
791
792
793

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted import path of the device communicator class."""
        communicator_path = (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
        return communicator_path
797

798
799
    @classmethod
    def supports_mx(cls) -> bool:
800
        return any(gfx in _GCN_ARCH for gfx in ["gfx95"])
801

802
803
    @classmethod
    def supports_fp8(cls) -> bool:
804
        return on_gfx9() or on_gfx12x()
805
806
807
808

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return whether `_GCN_ARCH` contains the "gfx94" marker."""
        # Only device 0 is checked; MI300 platforms are assumed homogeneous.
        fnuz_arch_marker = "gfx94"
        return fnuz_arch_marker in _GCN_ARCH
810
811
812
813
814
815
816

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn
817

818
819
820
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True when `_GCN_ARCH` contains "gfx94" or "gfx95"."""
        # Custom allreduce is only enabled for the MI300 series.
        for marker in ["gfx94", "gfx95"]:
            if marker in _GCN_ARCH:
                return True
        return False
822

823
824
825
826
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Report the opaque-attention-op flag; always True on this platform."""
        return True

827
828
    @classmethod
    def is_navi(cls) -> bool:
        """Return whether `_GCN_ARCH` contains the "gfx1" marker."""
        navi_arch_marker = "gfx1"
        return navi_arch_marker in _GCN_ARCH
830
831

    @classmethod
832
833
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted import path of the static graph wrapper class."""
        wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return wrapper_path
834

835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a `ProcessGroup` backed by a newly created NCCL backend.

        Args:
            backend: Not read by this implementation; NCCL is always used.
            prefix_store: Store handed to both the group and its NCCL backend.
            group_rank: Rank forwarded to the group and the NCCL backend.
            group_size: Size forwarded to the group and the NCCL backend.
            timeout: Assigned to the NCCL backend options' `_timeout`.

        Returns:
            The process group, with NCCL set as its default backend and
            registered for the "cuda" device.

        Raises:
            AssertionError: If `is_nccl_available()` is false.
        """
        assert is_nccl_available()
        group: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

        # Imported only after the availability check above has passed.
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        nccl_options = ProcessGroupNCCL.Options()
        nccl_options._timeout = timeout
        nccl_backend = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, nccl_options
        )

        nccl_type = ProcessGroup.BackendType.NCCL
        group._set_default_backend(nccl_type)
        nccl_backend._set_sequence_number_for_group()
        group._register_backend(torch.device("cuda"), nccl_type, nccl_backend)
        return group

866
867
    @classmethod
    def device_count(cls) -> int:
        """Return the device count reported by `_rocm_device_count_stateless`.

        The helper is given the `envs` attribute whose name is stored in
        `cls.device_control_env_var`.
        """
        device_control_value = getattr(envs, cls.device_control_env_var)
        return _rocm_device_count_stateless(device_control_value)
869

870
    @classmethod
871
872
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise if `dtype` is bfloat16 and the device lacks capability 8.0.

        Any other dtype, or a device for which `cls.has_device_capability(80)`
        holds, passes silently.

        Raises:
            ValueError: If bfloat16 is requested on an unsupported device.
        """
        is_bfloat16 = dtype == torch.bfloat16
        if not is_bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        compute_str = (
            "does not have a compute capability"
            if capability is None
            else f"has compute capability {capability.as_version_str()}"
        )

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
890

891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy selected blocks of `src_cache` into `dst_cache` in place.

        Blocks are addressed along dimension 1 of both caches. The gathered
        source blocks are moved to `dst_cache.device` before being written.
        """
        gathered = src_cache[:, src_block_indices]
        on_target_device = gathered.to(dst_cache.device)
        dst_cache[:, dst_block_indices] = on_target_device

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy selected blocks of `src_cache` into `dst_cache` via the CPU.

        Blocks are addressed along dimension 1 of both caches. The gathered
        source blocks are converted with `.cpu()` before being written.
        """
        gathered = src_cache[:, src_block_indices]
        host_blocks = gathered.cpu()
        dst_cache[:, dst_block_indices] = host_blocks

915
916
917
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Report the hybrid KV cache flag; always True on this platform."""
        return True
918
919
920
921

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Report the static graph mode flag; always True on this platform."""
        return True
922
923

    @classmethod
924
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Return `multi_processor_count` from the properties of `device_id`."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.multi_processor_count
926
927
928
929

    @classmethod
    def use_custom_op_collectives(cls) -> bool:
        """Report the custom-op collectives flag; always True on this platform."""
        return True
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958

    @classmethod
    def get_default_ir_op_priority(
        cls, vllm_config: "VllmConfig"
    ) -> "IrOpPriorityConfig":
        """Build the default IR op implementation priorities for ROCm.

        The baseline priority list is ["native"] when the compilation backend
        is "inductor" with a mode other than `CompilationMode.NONE`, and
        ["vllm_c", "native"] otherwise. `rms_norm` gets "aiter" prepended when
        the `rms_norm` custom op is enabled and both AITER env flags are set.
        """
        from vllm.config.compilation import CompilationMode
        from vllm.config.kernel import IrOpPriorityConfig

        compilation = vllm_config.compilation_config

        # Native is used by default when compiling; without codegen, prefer
        # vllm_c kernels where available.
        # TODO(luka/TJ) use aiter, vllm_c, native by default on ROCm
        if compilation.backend == "inductor" and (
            compilation.mode != CompilationMode.NONE
        ):
            default = ["native"]
        else:
            default = ["vllm_c", "native"]

        # This (mostly) preserves previous CustomOp behavior. It is needed on
        # ROCm because users commonly enable rms_norm to get the aiter kernel.
        # TODO(luka/TJ) remove env vars completely
        prefer_aiter_rms_norm = (
            compilation.is_custom_op_enabled("rms_norm")
            and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_RMSNORM
        )
        rms_norm = ["aiter", *default] if prefer_aiter_rms_norm else default

        return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985

    @classmethod
    @with_amdsmi_context
    def get_all_device_numa_nodes(cls) -> list[int] | None:
        """Get NUMA nodes for all visible GPU devices.

        Returns:
            One NUMA node number per logical device index, in index order, or
            None when any lookup fails (a warning is logged in that case).
        """
        try:
            processor_handles = amdsmi_get_processor_handles()
            nodes = []
            for logical_id in range(cls.device_count()):
                # amdsmi handles are indexed by physical device id.
                physical_id = cls.device_id_to_physical_device_id(logical_id)
                handle = processor_handles[physical_id]
                try:
                    node = amdsmi_topo_get_numa_node_number(handle)
                except AmdSmiException as e:
                    # A single failed lookup gives up on the whole result.
                    logger.warning(
                        "Could not detect NUMA node for GPU %d, "
                        "disabling automatic NUMA binding: %s",
                        logical_id,
                        e,
                    )
                    return None
                nodes.append(node)
            return nodes
        except Exception as e:
            # Best-effort query: any other failure is logged, not raised.
            logger.warning("Failed to get NUMA nodes for GPUs: %s", e)
            return None