# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.logger import init_logger
12
from vllm.utils.torch_utils import cuda_device_count_stateless
13

14
from .interface import DeviceCapability, Platform, PlatformEnum
15

16
if TYPE_CHECKING:
17
    from vllm.attention.backends.registry import AttentionBackendEnum
18
    from vllm.config import VllmConfig
19
else:
20
    AttentionBackendEnum = None
21

logger = init_logger(__name__)

# AMDSMI is optional at import time; if it is missing, the AMDSMI-backed
# helpers below will fail when called rather than at import.
try:
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# NOTE(review): _ROCM_SWA_REASON is not referenced in the visible code; kept
# as-is in case other modules import it.
_ROCM_SWA_REASON = ()
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# AMDSMI ASIC "device_id" string -> name returned by
# RocmPlatform.get_device_name(). Ids not listed here fall back to the AMDSMI
# "market_name" field.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    "0x744c": "AMD_Radeon_RX7900XTX",
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`: a HIP_VISIBLE_DEVICES
# value is mirrored into CUDA_VISIBLE_DEVICES, and a non-empty
# CUDA_VISIBLE_DEVICES that disagrees with it is rejected.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Explicit raise instead of `assert`, so the check is not stripped
        # under `python -O` and the user gets an actionable message.
        if val != cuda_val:
            raise ValueError(
                f"HIP_VISIBLE_DEVICES ({val!r}) and CUDA_VISIBLE_DEVICES "
                f"({cuda_val!r}) are both set but differ; set only one of "
                "them or make them identical."
            )
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI (like NVML on NVIDIA) is not affected by
# `{CUDA/HIP}_VISIBLE_DEVICES`; all the related functions work on real
# physical device ids. The major benefit of using AMDSMI is that it does not
# initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so each call runs inside its own AMDSMI session.

    ``amdsmi_init()`` is called before ``fn`` runs and ``amdsmi_shut_down()``
    is always called afterwards, including when ``fn`` raises. Function
    metadata (name, docstring, ``__wrapped__``) is preserved via ``wraps``.
    """

    @wraps(fn)
    def _run_in_amdsmi_session(*args, **kwargs):
        amdsmi_init()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return outcome

    return _run_in_amdsmi_session


@cache
def on_gfx1x() -> bool:
    """Report whether the current "cuda" device has a gfx11 or gfx12 arch.

    The first answer is memoized for the rest of the process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx11", "gfx12"):
        if family in arch_name:
            return True
    return False


@cache
def on_mi3xx() -> bool:
    """Report whether the current "cuda" device has a gfx942 or gfx950 arch.

    The device's ``gcnArchName`` is checked for either substring; the first
    answer is memoized for the rest of the process.
    """
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])


@cache
def on_gfx9() -> bool:
    """Report whether the current "cuda" device is gfx90a, gfx942 or gfx950.

    The first answer is memoized for the rest of the process.
    """
    gfx9_archs = ("gfx90a", "gfx942", "gfx950")
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    matches = [arch for arch in gfx9_archs if arch in arch_name]
    return len(matches) > 0


@cache
def on_gfx950() -> bool:
    """Report whether the current "cuda" device's arch name contains gfx950.

    The first answer is memoized for the rest of the process.
    """
    props = torch.cuda.get_device_properties("cuda")
    return "gfx950" in props.gcnArchName


@lru_cache(maxsize=256)
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int | tuple[int, int],
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Return True if ROCm's custom paged-attention kernel can serve a call.

    Args:
        qtype: query dtype; only ``torch.half`` / ``torch.bfloat16`` qualify.
        head_size: 64 or 128 on gfx9; 128 on gfx11/gfx12.
        block_size: KV-cache block size; 16 or 32 on gfx9, 16 on gfx11/gfx12.
        gqa_ratio: must be in [1, 16] on gfx9 and [3, 16] on gfx11/gfx12.
        max_seq_len: maximum sequence length for the call; must be
            <= 128 * 1024.
        sliding_window: 0 or (-1, -1) means "disabled"; any other value rules
            the kernel out.
        kv_cache_dtype: must be "auto" on gfx11/gfx12 (not checked on gfx9).
        alibi_slopes: must be None on gfx11/gfx12 (not checked on gfx9).
        sinks: must be None.

    The kernel must also be enabled via ``VLLM_ROCM_CUSTOM_PAGED_ATTN``; on
    gfx9 it is additionally skipped when AITER paged attention is enabled.
    Any other architecture returns False.

    Results are memoized with a *bounded* LRU cache: ``max_seq_len`` may vary
    between calls and tensor arguments hash by identity, so an unbounded
    cache could grow indefinitely and keep those tensors alive.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    on_gfx9_arch = any(arch in gpu_arch for arch in ["gfx90a", "gfx942", "gfx950"])
    on_gfx11_gfx12 = any(arch in gpu_arch for arch in ["gfx11", "gfx12"])
    sliding_window_disabled = sliding_window == 0 or sliding_window == (-1, -1)

    if on_gfx9_arch:
        # Sliding window must be disabled: a numerical discrepancy was
        # observed with it enabled.
        return (
            sliding_window_disabled
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and not rocm_aiter_ops.is_pa_attn_enabled()
            and sinks is None
        )

    return (
        on_gfx11_gfx12
        and sliding_window_disabled
        and qtype in (torch.half, torch.bfloat16)
        and head_size == 128
        and block_size == 16
        and 3 <= gqa_ratio <= 16
        and max_seq_len <= 128 * 1024
        and alibi_slopes is None
        and kv_cache_dtype == "auto"
        and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
        and sinks is None
    )


class RocmPlatform(Platform):
    """Platform description and hooks for AMD GPUs running under ROCm.

    ROCm is driven through PyTorch's ``cuda`` device type, so this class uses
    the CUDA dispatch key, ``CUDA_VISIBLE_DEVICES`` and ``torch.cuda`` APIs,
    while topology and device-name queries go through AMDSMI.
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
    ]
    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
    # NOTE(review): this executes while the class body is evaluated, so
    # on_gfx9() queries torch.cuda device properties at module import time.
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> AttentionBackendEnum:
        """Choose the attention backend for vision-encoder (ViT) attention.

        Order of preference: ROCM_AITER_FA when AITER MHA is enabled, then
        FLASH_ATTN on gfx9 if the ``flash_attn`` package is importable, and
        TORCH_SDPA otherwise.
        """
        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.attention.backends.registry import AttentionBackendEnum

        if rocm_aiter_ops.is_mha_enabled():
            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            return AttentionBackendEnum.ROCM_AITER_FA

        if on_gfx9() and find_spec("flash_attn") is not None:
            return AttentionBackendEnum.FLASH_ATTN

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        attn_type: str | None = None,
    ) -> str:
        """Return the import path of the attention backend class to use.

        Sparse and MLA attention are resolved first. Otherwise an explicitly
        selected backend is honoured (after validation); with
        ``selected_backend`` None, a backend is chosen from the ROCm/AITER
        environment flags, defaulting to Triton attention.

        Raises:
            ValueError: if the requested configuration is not supported by the
                selected or implied backend.
            RuntimeError: if ``selected_backend`` is a backend this platform
                does not handle.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.attention.backends.registry import AttentionBackendEnum

        if use_sparse:
            if kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            # Explicit raise rather than `assert`, so `python -O` keeps it.
            if block_size != 1:
                raise ValueError(
                    "Sparse MLA backend on ROCm only supports block size 1 for now."
                )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if use_mla:
            if selected_backend is None:
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            else:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
                logger.info("Using Rocm Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled.
            # NOTE(review): if VLLM_ROCM_USE_AITER_MHA is always a bool, this
            # condition is equivalent to Priority 2 and therefore unreachable;
            # it only matters if the flag can be non-bool (e.g. None).
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        # Any other explicitly selected backend is not handled on ROCm.
        raise RuntimeError(
            f"The selected backend, {selected_backend.name}, is not supported "
            "on ROCm."
        )

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Wrap ``torch.cuda.get_device_capability`` in a DeviceCapability."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop).

        Every unordered pair of the given physical devices must report a
        1-hop link of type 2 (XGMI). Returns False on the first failing pair
        or if AMDSMI raises while querying a link.
        """
        # Fetch the handle list once instead of once per requested id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
                # type is 2 for XGMI
                if link_type["hops"] != 1 or link_type["type"] != 2:
                    return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a human-readable name for logical device ``device_id``.

        The logical id is mapped to a physical one, the ASIC ``device_id`` is
        looked up in ``_ROCM_DEVICE_ID_NAME_MAP``, and AMDSMI's
        ``market_name`` is returned for unknown ids.

        The cache wraps the AMDSMI context (not the other way round), so
        cache hits do not run ``amdsmi_init`` / ``amdsmi_shut_down``.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        asic_device_id: str = asic_info["device_id"]
        if asic_device_id in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[asic_device_id]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device's torch properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply ROCm-specific defaults to ``vllm_config`` in place.

        - Defaults the KV-cache block size to 16 when it is unset.
        - Resolves ``worker_cls == "auto"`` to the V1 GPU worker.
        - Adds the ``+rms_norm`` custom op when AITER RMSNorm is enabled,
          CUDA graphs are not disabled, and the user did not pass
          ``-rms_norm``.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the configured CUDA-graph mode: comparing the config object
        # itself against an enum member would always be False.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE

        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially supported by ROCm: %s",
                model_arch,
                msg,
            )

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base check, then set ``VLLM_USE_TRITON_AWQ=1``.

        A warning is logged when ``quant`` is "awq" and the variable was not
        already enabled.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # NOTE(review): this assignment runs for every quantization method,
        # not only "awq"; confirm no other method relies on it before moving
        # it under the `if` above.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper for ROCm."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Reset peak stats and return the device's used memory (total - free).

        ``mem_get_info`` is queried once so free and total come from the same
        snapshot.
        """
        torch.cuda.reset_peak_memory_stats(device)
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator class."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_mx(cls) -> bool:
        """Return True if device 0's arch name contains "gfx95"."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in gcn_arch

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's arch is in the gfx94/gfx95/gfx12 family."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0 is gfx94x (uses the FNUZ fp8 variants)."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 e4m3 dtype flavour matching this device."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True if device 0's arch is gfx94x or gfx95x."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ["gfx94", "gfx95"]
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def is_navi(cls) -> bool:
        """Return True if device 0's arch name contains "gfx1"."""
        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static (CUDA) graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        """Return the visible device count without initializing CUDA state."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ValueError if ``dtype`` is bfloat16 and capability < 8.0."""
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True