rocm.py 18.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.logger import init_logger
12
from vllm.utils.torch_utils import cuda_device_count_stateless
13

14
from .interface import DeviceCapability, Platform, PlatformEnum
15

16
if TYPE_CHECKING:
17
    from vllm.attention.backends.registry import AttentionBackendEnum
18
    from vllm.config import VllmConfig
19
else:
20
    AttentionBackendEnum = None
21

22
23
logger = init_logger(__name__)

24
try:
25
26
27
28
29
30
31
32
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
33
34
35
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

36
37
38
39
40
41
42
43
44
45
46
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

47
# Models not supported by ROCm.
48
_ROCM_UNSUPPORTED_MODELS: list[str] = []
49
50
51

# Models partially supported by ROCm.
# Architecture -> Reason.
52
53
_ROCM_SWA_REASON = ()
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}
54
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
55
56
57
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
58
    "0x74a2": "AMD_Instinct_MI308X",
59
60
61
62
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
63
    "0x744c": "AMD_Radeon_RX7900XTX",
64
}
65

66
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# When HIP_VISIBLE_DEVICES is set, keep CUDA_VISIBLE_DEVICES in agreement
# with it at import time.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    # A non-empty CUDA_VISIBLE_DEVICES must match exactly; an unset or empty
    # one is overwritten with the HIP value.
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        assert val == cuda_val
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so each call is bracketed by AMDSMI init/shutdown.

    ``amdsmi_init()`` runs before ``fn``; ``amdsmi_shut_down()`` runs after
    it, including when ``fn`` raises.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Always release the AMDSMI session, even on error.
            amdsmi_shut_down()
        return result

    return wrapper


92
93
94
95
96
97
@cache
def on_gfx1x() -> bool:
    """Return True if the device's ``gcnArchName`` contains gfx11 or gfx12.

    The result is memoized by ``functools.cache``.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx11" in arch_name or "gfx12" in arch_name


98
@cache
def on_mi3xx() -> bool:
    """Return True if the device's ``gcnArchName`` contains gfx942 or gfx950.

    The result is memoized by ``functools.cache``.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx942" in arch_name or "gfx950" in arch_name


@cache
def on_gfx9() -> bool:
    """Return True if ``gcnArchName`` contains gfx90a, gfx942 or gfx950.

    The result is memoized by ``functools.cache``.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    gfx9_tags = ("gfx90a", "gfx942", "gfx950")
    return any(tag in arch_name for tag in gfx9_tags)
108
109


110
111
112
113
114
115
@cache
def on_gfx950() -> bool:
    """Return True if the device's ``gcnArchName`` contains gfx950.

    The result is memoized by ``functools.cache``.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx950" in arch_name


116
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Decide whether the ROCm custom paged-attention kernel may be used.

    The device architecture is read from ``gcnArchName``. Two rule sets
    exist: one for gfx90a/gfx942/gfx950 and one for gfx11/gfx12; any other
    architecture yields False. ``alibi_slopes`` and ``kv_cache_dtype`` are
    only consulted by the gfx11/gfx12 rules. Results are memoized per
    argument tuple by ``functools.cache``.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    is_gfx9 = any(tag in arch_name for tag in ("gfx90a", "gfx942", "gfx950"))
    is_gfx11_or_gfx12 = any(tag in arch_name for tag in ("gfx11", "gfx12"))

    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    if is_gfx9:
        return (
            sliding_window in (0, (-1, -1))
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024
            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
            and not (rocm_aiter_ops.is_pa_attn_enabled())
            and sinks is None
        )

    if not is_gfx11_or_gfx12:
        return False

    return (
        sliding_window in (0, (-1, -1))
        and qtype in (torch.half, torch.bfloat16)
        and head_size == 128
        and block_size == 16
        and 3 <= gqa_ratio <= 16
        and max_seq_len <= 128 * 1024
        and alibi_slopes is None
        and kv_cache_dtype == "auto"
        and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
        and sinks is None
    )
163
164


165
166
class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
167
    device_name: str = "rocm"
168
    device_type: str = "cuda"
169
    dispatch_key: str = "CUDA"
170
    ray_device_key: str = "GPU"
171
    dist_backend: str = "nccl"
172
173
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
174

175
    supported_quantization: list[str] = [
176
177
178
179
180
181
182
183
184
185
186
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
187
    ]
188
189
190
    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]
191

192
    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> AttentionBackendEnum:
        """Pick the attention backend for vision-transformer layers.

        Preference order: AITER flash attention when AITER MHA is enabled,
        then ``flash_attn`` on gfx9 when that package is importable, else
        torch SDPA. ``head_size`` and ``dtype`` are not consulted here.
        """
        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.attention.backends.registry import AttentionBackendEnum

        # Note: AITER FA is only supported for Qwen-VL models.
        # TODO: Add support for other VL models in their model class.
        if rocm_aiter_ops.is_mha_enabled():
            backend = AttentionBackendEnum.ROCM_AITER_FA
        elif on_gfx9() and find_spec("flash_attn") is not None:
            backend = AttentionBackendEnum.FLASH_ATTN
        else:
            backend = AttentionBackendEnum.TORCH_SDPA
        return backend
210

211
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        attn_type: str | None = None,
    ) -> str:
        """Resolve the path of the attention backend to use on ROCm.

        Resolution order:
          1. ``use_sparse``: the AITER sparse MLA backend (rejects fp8 KV
             cache, requires ``block_size == 1``).
          2. ``use_mla``: an MLA backend; when ``selected_backend`` is None,
             AITER MLA is chosen if it is enabled or ``block_size == 1``,
             otherwise Triton MLA.
          3. An explicitly selected non-MLA backend.
          4. ``selected_backend is None``: automatic selection driven by
             the ``VLLM_ROCM_USE_AITER*`` and
             ``VLLM_V1_USE_PREFILL_DECODE_ATTENTION`` settings, defaulting
             to the Triton attention backend.

        Returns:
            The value of ``get_path()`` for the chosen backend.

        Raises:
            ValueError: incompatible backend / block size / KV-cache dtype,
                or a non-MLA backend selected while ``use_mla`` is set.
            RuntimeError: ``selected_backend`` is set but matches none of
                the backends handled here.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.attention.backends.registry import AttentionBackendEnum

        if use_sparse:
            if kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if use_mla:
            if selected_backend is None:
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                # Fixed message: fragments previously joined without a space.
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            # Fixed message: fragments previously joined without a space.
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            else:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
                logger.info("Using Rocm Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend."
        )
327

328
329
330
331
332
333
334
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.

        Delegates to ``torch.cuda.set_device``.
        """
        torch.cuda.set_device(device)

335
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the capability torch reports for ``device_id``.

        The ``(major, minor)`` pair from ``torch.cuda.get_device_capability``
        is wrapped in a ``DeviceCapability``; results are cached (up to 8
        entries).
        """
        cc_major, cc_minor = torch.cuda.get_device_capability(device_id)
        capability = DeviceCapability(major=cc_major, minor=cc_minor)
        return capability
340

341
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Every unordered pair of the given physical devices is checked. The
        result is False as soon as one pair is not a 1-hop XGMI link, or
        when the AMDSMI topology query raises ``AmdSmiException`` (which is
        logged). With fewer than two devices there are no pairs and the
        result is True.
        """
        from itertools import combinations

        if not physical_device_ids:
            # Nothing to check; skip the AMDSMI enumeration entirely.
            return True

        # Enumerate processors once instead of once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]

        for handle, peer_handle in combinations(handles, 2):
            try:
                link_type = amdsmi_topo_get_link_type(handle, peer_handle)
            except AmdSmiException as error:
                logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                return False
            # type is 2 for XGMI
            if link_type["hops"] != 1 or link_type["type"] != 2:
                return False
        return True

361
    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for the logical device ``device_id``.

        The logical id is translated to a physical id, the ASIC info is
        read through AMDSMI, and the ASIC ``device_id`` is looked up in
        ``_ROCM_DEVICE_ID_NAME_MAP``. Unknown ids fall back to the ASIC
        ``market_name``.

        The cache is the outermost wrapper (below ``classmethod``) so that
        cache hits do not open and close an AMDSMI session.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]
372
373
374
375
376

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's properties for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory
377
378

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply ROCm-specific defaults to ``vllm_config`` in place.

        - Sets the cache block size to 16 when it is unset.
        - Resolves ``worker_cls == "auto"`` to the V1 GPU worker.
        - Appends the ``+rms_norm`` custom op when AITER RMS norm is
          enabled, CUDA graphs are not disabled, and ``-rms_norm`` was not
          requested.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the configured cudagraph mode, not the config object itself;
        # the latter never equals an enum member, so eager mode went unnoticed.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE

        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")
402

403
404
405
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Check ``model_arch`` against the ROCm support tables.

        Raises:
            ValueError: ``model_arch`` is listed in
                ``_ROCM_UNSUPPORTED_MODELS``.

        Architectures listed in ``_ROCM_PARTIALLY_SUPPORTED_MODELS`` only
        produce a warning carrying the recorded reason.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch not in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            return

        reason = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            reason,
        )
417

418
419
420
421
422
423
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Validate ``quant`` via the base class, then force Triton AWQ on.

        A warning is logged when AWQ is requested while
        ``VLLM_USE_TRITON_AWQ`` is not enabled.
        """
        super().verify_quantization(quant)

        triton_awq_missing = quant == "awq" and not envs.VLLM_USE_TRITON_AWQ
        if triton_awq_missing:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )

        # NOTE(review): the variable is set for every quantization method,
        # not only AWQ -- confirm this is intentional.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
427
428
429
430

    @classmethod
    def get_punica_wrapper(cls) -> str:
        # Dotted path of the Punica (LoRA) wrapper class used on this platform.
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
431
432

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the memory currently in use on ``device``, in bytes.

        Peak-memory statistics are reset first. Usage is computed as total
        minus free from a single ``torch.cuda.mem_get_info`` query.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # One query yields both numbers; no need to hit the driver twice.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes
438
439
440

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        communicator_path = (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
        return communicator_path
444

445
446
447
448
449
    @classmethod
    def supports_mx(cls) -> bool:
        """Return True if device 0's ``gcnArchName`` contains "gfx95"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in arch_name

450
451
452
    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's arch name contains gfx94, gfx95 or gfx12."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        fp8_arch_tags = ("gfx94", "gfx95", "gfx12")
        return any(tag in arch_name for tag in fp8_arch_tags)
454
455
456
457

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0's ``gcnArchName`` contains "gfx94"."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name
459
460
461
462
463
464
465

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 dtype: the fnuz variant when ``is_fp8_fnuz()`` holds."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn
466

467
468
469
470
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True if device 0's arch name contains "gfx94" or "gfx95"."""
        # We only enable custom allreduce for MI300 series
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name or "gfx95" in arch_name
473

474
475
476
477
    @classmethod
    def opaque_attention_op(cls) -> bool:
        # Always reports True for the ROCm platform.
        return True

478
479
    @classmethod
    def is_navi(cls) -> bool:
        """Return True if device 0's ``gcnArchName`` contains "gfx1"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx1" in arch_name
481
482

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted path of the static-graph wrapper class."""
        wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return wrapper_path
485

486
487
488
    @classmethod
    def device_count(cls) -> int:
        # Delegates to the stateless CUDA device counter.
        return cuda_device_count_stateless()
489

490
    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Reject bfloat16 on devices below compute capability 8.0.

        Any other dtype, or a device with sufficient capability, passes
        silently.

        Raises:
            ValueError: ``dtype`` is ``torch.bfloat16`` and
                ``has_device_capability(80)`` is false.
        """
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        compute_str = (
            "does not have a compute capability"
            if capability is None
            else f"has compute capability {capability.as_version_str()}"
        )

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
510
511
512
513

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        # Always reports True for the ROCm platform.
        return True
514
515
516
517

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        # Always reports True for the ROCm platform.
        return True