rocm.py 19.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.attention.backends.registry import AttentionBackendEnum
12
from vllm.logger import init_logger
13
from vllm.utils.torch_utils import cuda_device_count_stateless
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

17
if TYPE_CHECKING:
18
    from vllm.config import VllmConfig
19

20
21
logger = init_logger(__name__)

22
try:
23
24
25
26
27
28
29
30
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
31
32
33
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

34
35
36
37
38
39
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
zhuwenwen's avatar
zhuwenwen committed
40
41
42
43
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)
44

45
# Models not supported by ROCm.
46
# Architectures listed here make `RocmPlatform.verify_model_arch` raise.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# NOTE(review): `_ROCM_SWA_REASON` is not referenced in the visible source.
_ROCM_SWA_REASON = ()
# Architectures listed here only trigger a warning in `verify_model_arch`.
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}
# Maps the `device_id` string reported in the amdsmi ASIC info to the device
# name returned by `RocmPlatform.get_device_name`.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    "0x744c": "AMD_Radeon_RX7900XTX",
}
63

64
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
65
66
67
68
69
70
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so that it runs between ``amdsmi_init`` and
    ``amdsmi_shut_down``.

    ``amdsmi_shut_down`` is always called, even when ``fn`` raises; the
    exception (if any) still propagates to the caller.
    """

    @wraps(fn)
    def _run_inside_amdsmi(*call_args, **call_kwargs):
        amdsmi_init()
        try:
            result = fn(*call_args, **call_kwargs)
        finally:
            # Release the AMDSMI session regardless of the outcome.
            amdsmi_shut_down()
        return result

    return _run_inside_amdsmi


90
@cache
def on_gfx1x() -> bool:
    """Return True if the "cuda" device's ``gcnArchName`` contains
    "gfx11" or "gfx12".

    The result is memoized, so the device properties are read at most once.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx11", "gfx12"):
        if family in arch_name:
            return True
    return False


96
@cache
def on_mi3xx() -> bool:
    """Return True if the "cuda" device's ``gcnArchName`` contains
    "gfx942" or "gfx950".

    The result is memoized, so the device properties are read at most once.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for target in ("gfx942", "gfx950"):
        if target in arch_name:
            return True
    return False


@cache
def on_gfx9() -> bool:
    """Return True if the "cuda" device's ``gcnArchName`` contains
    "gfx90a", "gfx942" or "gfx950".

    The result is memoized, so the device properties are read at most once.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for target in ("gfx90a", "gfx942", "gfx950"):
        if target in arch_name:
            return True
    return False
106
107


108
109
110
111
112
113
@cache
def on_gfx950() -> bool:
    """Return True if the "cuda" device's ``gcnArchName`` contains "gfx950".

    The result is memoized, so the device properties are read at most once.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx950" in arch_name


114
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Report whether the ROCm custom paged-attention kernel should be used.

    The kernel is disabled here: every combination of arguments yields
    ``False``. The arguments are kept so existing call sites keep working.

    Disabled heuristic, kept for reference:

    * gfx9 (gfx90a / gfx942 / gfx950): sliding window off (``0`` or
      ``(-1, -1)``), ``qtype`` half or bfloat16, ``head_size`` 64 or 128,
      ``block_size`` 16 or 32, ``1 <= gqa_ratio <= 16``,
      ``max_seq_len <= 128 * 1024``, ``VLLM_ROCM_CUSTOM_PAGED_ATTN`` set,
      AITER paged attention not enabled, and ``sinks is None``.
    * gfx11 / gfx12: sliding window off, ``qtype`` half or bfloat16,
      ``head_size == 128``, ``block_size == 16``, ``3 <= gqa_ratio <= 16``,
      ``max_seq_len <= 128 * 1024``, ``alibi_slopes is None``,
      ``kv_cache_dtype == "auto"``, ``VLLM_ROCM_CUSTOM_PAGED_ATTN`` set, and
      ``sinks is None``.

    The disabled code noted that custom paged attention was always supported
    on V0, while V1 required sliding window to be off because of an observed
    numerical discrepancy.
    """
    # Kept from the original body; only the disabled heuristic used it.
    from vllm._aiter_ops import rocm_aiter_ops  # noqa: F401

    custom_paged_attention_enabled = False
    return custom_paged_attention_enabled
162
163


164
165
class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
166
    device_name: str = "rocm"
167
    device_type: str = "cuda"
168
    dispatch_key: str = "CUDA"
169
    ray_device_key: str = "GPU"
170
    dist_backend: str = "nccl"
171
172
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
173

174
    supported_quantization: list[str] = [
175
176
177
178
179
180
181
182
183
184
185
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
186
    ]
187
188
189
    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]
190

191
    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> AttentionBackendEnum:
        """Pick the attention backend for vision-transformer attention.

        Preference order: AITER flash attention when AITER MHA is enabled,
        then ``flash_attn`` on gfx9 if the package is importable, otherwise
        torch SDPA. ``head_size`` and ``dtype`` are not consulted.
        """
        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_mha_enabled():
            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            backend = AttentionBackendEnum.ROCM_AITER_FA
        elif on_gfx9() and find_spec("flash_attn") is not None:
            backend = AttentionBackendEnum.FLASH_ATTN
        else:
            backend = AttentionBackendEnum.TORCH_SDPA
        return backend
208

209
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        use_mm_prefix,
        attn_type: str | None = None,
    ) -> str:
        """Resolve the attention backend and return its import path.

        Only ``selected_backend``, ``kv_cache_dtype``, ``block_size``,
        ``use_mla`` and ``use_sparse`` influence the decision; the remaining
        parameters are accepted but not consulted here.

        Resolution order:
          1. ``use_sparse``: the AITER sparse MLA backend (block size 1,
             non-fp8 KV cache only).
          2. ``use_mla``: Triton MLA (the default when nothing is selected);
             block size 1 is rejected.
          3. An explicitly selected FLEX_ATTENTION, TRITON_ATTN or ROCM_ATTN.
          4. No selection: choose from the ``VLLM_ROCM_USE_AITER*`` and
             ``VLLM_V1_USE_PREFILL_DECODE_ATTENTION`` environment settings,
             falling back to Triton attention.

        Returns:
            The value of ``get_path()`` for the chosen backend.

        Raises:
            ValueError: sparse MLA requested with an fp8 KV cache, Triton MLA
                requested with block size 1, or a non-MLA backend selected
                while ``use_mla`` is set.
            AssertionError: sparse MLA requested with ``block_size != 1``.
            RuntimeError: the explicitly selected backend is not handled here.
        """
        # Kept from the original body; the branches that used it are disabled.
        from vllm._aiter_ops import rocm_aiter_ops  # noqa: F401

        if use_sparse:
            # Guard against a missing dtype before calling ``startswith``.
            if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if use_mla:
            if selected_backend is None:
                # The disabled default preferred ROCM_AITER_MLA when AITER MLA
                # was enabled or block_size == 1; Triton MLA is now always
                # the default.
                selected_backend = AttentionBackendEnum.TRITON_MLA
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            # Explicit selection of ROCM_AITER_MLA / ROCM_AITER_TRITON_MLA is
            # disabled, so those backends fall through to the error below.
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        # NOTE(review): explicit selection of ROCM_AITER_FA and
        # ROCM_AITER_UNIFIED_ATTN is disabled (it ends in the RuntimeError
        # below), yet the automatic selection can still return them.

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
                logger.info("Using Rocm Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )
326

327
328
329
330
331
332
333
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the active device via ``torch.cuda.set_device``."""
        torch.cuda.set_device(device)

334
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the capability torch reports for ``device_id``.

        Results are memoized (up to 8 distinct argument combinations).
        """
        cap_major, cap_minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=cap_major, minor=cap_minor)
339

340
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Args:
            physical_device_ids: Indices into the list returned by
                ``amdsmi_get_processor_handles()``.

        Returns:
            True when every unordered pair of the given devices reports a
            link with ``hops == 1`` and ``type == 2``. False when any pair
            does not, or when the link query raises ``AmdSmiException``.
            An empty or single-element list yields True.
        """
        # Fetch the handle table once instead of once per requested id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]

        for i, handle in enumerate(handles):
            # Only look at peers after ``handle`` so each pair is tested once.
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
                # type is 2 for XGMI
                if link_type["hops"] != 1 or link_type["type"] != 2:
                    return False
        return True

360
    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of the device with logical index ``device_id``.

        The logical index is translated to a physical one, the ASIC info is
        read through AMDSMI, and its ``device_id`` is looked up in
        ``_ROCM_DEVICE_ID_NAME_MAP``. Unknown ids fall back to the
        ``market_name`` reported by AMDSMI.

        The cache is the outermost wrapper (under ``classmethod``), matching
        ``get_device_capability``, so a cache hit returns without opening and
        closing an AMDSMI session.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]
371
372
373
374
375

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's properties of ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory
376
377

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for the ROCm platform.

        Changes applied:
          * If full CUDA graphs are requested while decode or prefill context
            parallelism is larger than 1, ``cudagraph_mode`` is downgraded to
            PIECEWISE.
          * An unset cache ``block_size`` becomes 16.
          * A ``worker_cls`` of ``"auto"`` becomes the V1 GPU worker.
          * ``"+rms_norm"`` is appended to ``custom_ops`` when AITER RMS norm
            is enabled, CUDA graphs are in use, and ``"-rms_norm"`` was not
            requested.
          * ``"+quant_fp8"`` is appended to ``custom_ops`` when AITER fp8
            linear is enabled and ``"-quant_fp8"`` was not requested.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        # NOTE: the spelling below is the name this file has always called.
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        # Compare the *mode*, not the config object: comparing the config
        # object itself to the enum member is never true, which made this
        # guard a no-op.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE

        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

423
424
425
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Validate ``model_arch`` against the ROCm support tables.

        Raises:
            ValueError: ``model_arch`` is listed in
                ``_ROCM_UNSUPPORTED_MODELS``.

        A warning with the recorded reason is logged when ``model_arch`` is
        listed in ``_ROCM_PARTIALLY_SUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch not in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            return

        reason = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            reason,
        )
437

438
439
440
441
442
443
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class quantization check, then force Triton AWQ on.

        A warning is logged when ``quant`` is ``"awq"`` and
        ``VLLM_USE_TRITON_AWQ`` is not already enabled.
        """
        super().verify_quantization(quant)

        awq_without_triton = quant == "awq" and not envs.VLLM_USE_TRITON_AWQ
        if awq_without_triton:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )

        # NOTE(review): the variable is set for every quantization method, not
        # only for AWQ -- confirm this is intended.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
447
448
449
450

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
451
452

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Reset the peak-memory statistics of ``device`` and return
        ``torch.cuda.max_memory_allocated(device)`` read right afterwards.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # Disabled alternative: total minus free from torch.cuda.mem_get_info.
        peak_allocated = torch.cuda.max_memory_allocated(device)
        return peak_allocated
459
460
461

    @classmethod
    def get_device_communicator_cls(cls) -> str:
462
463
464
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
465

466
467
468
469
470
    @classmethod
    def supports_mx(cls) -> bool:
        """Return True if device 0's ``gcnArchName`` contains "gfx95"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in arch_name

471
472
473
    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's ``gcnArchName`` contains "gfx94",
        "gfx95" or "gfx12".
        """
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        for family in ("gfx94", "gfx95", "gfx12"):
            if family in arch_name:
                return True
        return False
475
476
477
478

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0's ``gcnArchName`` contains "gfx94"."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name
480
481
482
483
484
485
486

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return ``torch.float8_e4m3fnuz`` when ``is_fp8_fnuz()`` is true,
        otherwise ``torch.float8_e4m3fn``.
        """
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn
487

488
489
490
491
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True if device 0's ``gcnArchName`` contains "gfx94" or
        "gfx95".
        """
        # We only enable custom allreduce for MI300 series
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        for supported in ("gfx94", "gfx95"):
            if supported in arch_name:
                return True
        return False
494

495
496
497
498
    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

499
500
    @classmethod
    def is_navi(cls) -> bool:
        """Return True if device 0's ``gcnArchName`` contains "gfx1"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx1" in arch_name
502
503

    @classmethod
504
505
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
506

507
508
509
    @classmethod
    def device_count(cls) -> int:
        """Return the value of ``cuda_device_count_stateless()``."""
        visible_devices = cuda_device_count_stateless()
        return visible_devices
510
511

    @classmethod
512
513
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        if dtype == torch.bfloat16:  # noqa: SIM102
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
529
530
                    "`dtype` flag in CLI, for example: --dtype=half."
                )
531
532
533
534

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True
535
536
537
538

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True