rocm.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from datetime import timedelta
6
from functools import cache, lru_cache, wraps
7
from typing import TYPE_CHECKING, Optional
8
9

import torch
10
11
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
12

13
import vllm.envs as envs
14
from vllm.logger import init_logger
15
from vllm.utils import cuda_device_count_stateless
16

17
from .interface import DeviceCapability, Platform, PlatformEnum
18

19
if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
else:
    # At runtime `_Backend` only appears in string annotations; the methods
    # that need the real enum import it locally.
    _Backend = None

logger = init_logger(__name__)

# amdsmi is optional at import time: a failed import is only logged here.
# The helpers below that call amdsmi_* functions will then fail with
# NameError when they are actually invoked.
try:
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

# vllm._C is imported only for its import-time side effects (presumably op
# registration, as for vllm._rocm_C below); a missing build is not fatal.
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Model architectures ROCm cannot run at all (none at the moment).
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Model architectures that run on ROCm with caveats, mapped to the reason
# that is reported to the user.
_ROCM_SWA_REASON = (
    "Sliding window attention (SWA) is not yet supported in Triton flash "
    "attention. For half-precision SWA support, please use CK flash "
    "attention by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
)
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
    # These three share the sliding-window limitation above.
    **dict.fromkeys(
        ("Qwen2ForCausalLM", "MistralForCausalLM", "MixtralForCausalLM"),
        _ROCM_SWA_REASON,
    ),
    "PaliGemmaForConditionalGeneration": (
        "ROCm flash attention does not yet fully support 32-bit "
        "precision on PaliGemma"
    ),
    "Phi3VForCausalLM": (
        "ROCm Triton flash attention may run into compilation errors due "
        "to excessive use of shared memory. If this happens, disable "
        "Triton FA by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
    ),
}

# amdsmi ASIC `device_id` -> product name, used in preference to the
# reported market name. Insertion order follows the grouping below.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    asic_id: product
    for product, asic_ids in (
        ("AMD_Instinct_MI300A", ("0x74a0",)),
        ("AMD_Instinct_MI300X", ("0x74a1", "0x74b5")),  # 0x74b5: MI300X VF
        ("AMD_Instinct_MI325X", ("0x74a5", "0x74b9")),  # 0x74b9: MI325X VF
        ("AMD_Instinct_MI300X_HF", ("0x74a9", "0x74bd")),
    )
    for asic_id in asic_ids
}
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`: mirror
# HIP_VISIBLE_DEVICES into CUDA_VISIBLE_DEVICES, and refuse to continue if
# both are set (non-empty) to different values.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Raise explicitly: an `assert` would be stripped under `python -O`
        # and the two variables could then silently disagree.
        if val != cuda_val:
            raise ValueError(
                f"HIP_VISIBLE_DEVICES ({val!r}) and CUDA_VISIBLE_DEVICES "
                f"({cuda_val!r}) are both set but differ; set only one of "
                "them, or set both to the same value."
            )
    else:
        # Unset or empty CUDA_VISIBLE_DEVICES inherits the HIP value.
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`:
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so that each call runs inside an AMDSMI session.

    ``amdsmi_init()`` is called before every invocation and
    ``amdsmi_shut_down()`` after it, including when ``fn`` raises. If
    ``amdsmi_init()`` itself raises, ``fn`` is not called and no shutdown is
    attempted. The wrapper keeps ``fn``'s metadata via ``functools.wraps``.
    """

    @wraps(fn)
    def _in_amdsmi_session(*args, **kwargs):
        amdsmi_init()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return outcome

    return _in_amdsmi_session


def _gcn_arch_contains(*markers: str) -> bool:
    """Report whether any of ``markers`` occurs in the device's gcnArchName.

    The architecture string comes from the properties of the ``"cuda"``
    device, which is how this platform addresses its GPUs.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(marker in arch_name for marker in markers)


@cache
def on_gfx1x() -> bool:
    """Return True on gfx11* / gfx12* devices (cached after first call)."""
    return _gcn_arch_contains("gfx11", "gfx12")


@cache
def on_mi3xx() -> bool:
    """Return True on gfx942 / gfx950 devices (cached after first call)."""
    return _gcn_arch_contains("gfx942", "gfx950")


@cache
def on_gfx9() -> bool:
    """Return True on gfx90a / gfx942 / gfx950 devices.

    The result is cached after the first call.
    """
    return _gcn_arch_contains("gfx90a", "gfx942", "gfx950")


@cache
def on_gfx950() -> bool:
    """Return True on gfx950 devices (cached after first call)."""
    return _gcn_arch_contains("gfx950")


# Longest sequence for which the custom paged-attention kernel is enabled.
_ROCM_CUSTOM_PAGED_ATTN_MAX_SEQ_LEN = 128 * 1024


@cache
def _custom_paged_attention_supported(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len_supported: bool,
    sliding_window,
    kv_cache_dtype: str,
    has_alibi_slopes: bool,
    has_sinks: bool,
) -> bool:
    """Cached core of ``use_rocm_custom_paged_attention``.

    Takes only small hashable scalars so the cache cannot pin tensors in
    memory and stays bounded as ``max_seq_len`` varies between calls. The
    ``envs`` flags are read on the first evaluation of each key only.
    """
    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    sliding_window_ok = (
        not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1)
    )
    half_precision = qtype == torch.half or qtype == torch.bfloat16

    if on_gfx9():
        return (
            sliding_window_ok
            and half_precision
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len_supported
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
            and not has_sinks
        )

    return (
        on_gfx1x()
        and sliding_window_ok
        and half_precision
        and head_size == 128
        and block_size == 16
        and 3 <= gqa_ratio <= 16
        and max_seq_len_supported
        and not has_alibi_slopes
        and kv_cache_dtype == "auto"
        and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
        and not has_sinks
    )


def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
    alibi_slopes: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
) -> bool:
    """Decide whether ROCm's custom paged-attention kernel can be used.

    On gfx9 devices (gfx90a/gfx942/gfx950) it requires:
      * fp16/bf16 queries, head size 64 or 128, block size 16 or 32;
      * a GQA ratio in [1, 16] and no attention sinks;
      * ``VLLM_ROCM_CUSTOM_PAGED_ATTN`` enabled and the AITER
        paged-attention path not selected.
    On gfx11/gfx12 devices it requires:
      * fp16/bf16 queries, head size 128, block size 16;
      * a GQA ratio in [3, 16], no ALiBi slopes and no sinks;
      * ``kv_cache_dtype == "auto"`` and ``VLLM_ROCM_CUSTOM_PAGED_ATTN``.
    On both, ``max_seq_len`` must not exceed 128K. Under the V1 engine,
    ``sliding_window`` must be disabled, i.e. compare equal to ``0`` or
    ``(-1, -1)``. Any other device yields False.

    Only the None-ness of ``alibi_slopes``/``sinks`` and the 128K bound on
    ``max_seq_len`` affect the result, so they are reduced to booleans
    before the cached check. Caching the raw arguments would keep every
    tensor alive and add an entry per distinct ``max_seq_len``.
    """
    return _custom_paged_attention_supported(
        qtype,
        head_size,
        block_size,
        gqa_ratio,
        max_seq_len <= _ROCM_CUSTOM_PAGED_ATTN_MAX_SEQ_LEN,
        sliding_window,
        kv_cache_dtype,
        alibi_slopes is not None,
        sinks is not None,
    )


# Keep the cache-control hooks the previously `@cache`-decorated function
# exposed, so existing callers of `.cache_clear()` keep working.
use_rocm_custom_paged_attention.cache_clear = (  # type: ignore[attr-defined]
    _custom_paged_attention_supported.cache_clear
)
use_rocm_custom_paged_attention.cache_info = (  # type: ignore[attr-defined]
    _custom_paged_attention_supported.cache_info
)
class RocmPlatform(Platform):
    """vLLM platform implementation for AMD GPUs on ROCm.

    GPUs are addressed through PyTorch's ``cuda`` device type (see
    ``device_type``), so device queries below go through ``torch.cuda``.
    Topology and product-name queries use AMDSMI.
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
    ]

    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        """Pick the attention backend for vision-encoder (ViT) layers.

        AITER flash attention when AITER and its MHA kernels are enabled on a
        gfx9 device, flash attention on other gfx9 devices, PyTorch SDPA
        otherwise.
        """
        from vllm.attention.backends.registry import _Backend

        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            return _Backend.ROCM_AITER_FA
        if on_gfx9():
            return _Backend.FLASH_ATTN
        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink,
        use_sparse,
    ) -> str:
        """Return the dotted import path of the attention backend class.

        Raises:
            NotImplementedError: sparse attention was requested.
            RuntimeError: MLA was requested with ``use_v1`` false, or
                ``VLLM_USE_V1`` is disabled for a non-MLA model.
            ValueError: the selected MLA backend does not support
                ``block_size``, or the selected backend is not an MLA backend.
        """
        from vllm.attention.backends.registry import _Backend

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on ROCm.")
        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them."
                )

            from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
                is_aiter_mla_enabled,
            )

            if selected_backend is None:
                # AITER MLA is also the only MLA backend handling block size 1.
                selected_backend = (
                    _Backend.ROCM_AITER_MLA
                    if is_aiter_mla_enabled() or block_size == 1
                    else _Backend.TRITON_MLA
                )

            if selected_backend == _Backend.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == _Backend.ROCM_AITER_MLA:
                if block_size == 1:
                    logger.info("Using AITER MLA backend on V1 engine.")
                    return (
                        "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                    )
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size} "
                    "(currently only supports block size 1)."
                )
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if envs.VLLM_USE_V1:
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Flash Attention backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends."
                    "rocm_aiter_fa.AiterFlashAttentionBackend"
                )
            elif (
                (envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
                or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
                or selected_backend == _Backend.ROCM_ATTN
            ):
                # rocm specific backend, with aiter and/or
                #   triton prefix-prefill
                logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
                return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
            else:
                # default case, using triton unified attention
                logger.info("Using Triton Attention backend on V1 engine.")
                return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend."
        )

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]:
        """Return the (major, minor) capability torch reports for a device."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Every unordered pair of the given physical devices must report a
        1-hop link of type 2 (XGMI). Any AMDSMI error is logged and treated
        as "not fully connected".
        """
        # Fetch the handle list once instead of once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

    @classmethod
    # The cache sits outside the AMDSMI context so cache hits do not pay for
    # an amdsmi_init()/amdsmi_shut_down() round trip.
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the product name for a logical device id.

        The logical id is mapped to a physical id (AMDSMI works on physical
        ids). Known ASIC device ids are translated via
        ``_ROCM_DEVICE_ID_NAME_MAP``; otherwise AMDSMI's ``market_name`` is
        returned.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory torch reports for ``device_id``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill ROCm defaults into ``vllm_config`` in place.

        Defaults the KV-cache block size to 16 and the worker class to the V1
        GPU worker. Under V1 with AITER RMSNorm enabled and CUDA graphs not
        disabled, enables the ``rms_norm`` custom op unless it was explicitly
        disabled.
        """
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the configured CUDA-graph mode, not the config object: a
        # CompilationConfig never equals a CUDAGraphMode member, so comparing
        # the object itself always reported "not eager".
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE

        use_v1 = envs.VLLM_USE_V1
        use_aiter_rms_norm = (
            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
        )

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        # AITER RMSNorm performs best when CUDA graph capture is enabled.
        if (
            use_v1
            and use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: ``model_arch`` is listed in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially supported by ROCm: %s",
                model_arch,
                msg,
            )

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base check, then force the Triton AWQ path.

        Warns only for ``awq`` when ``VLLM_USE_TRITON_AWQ`` was not set.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # NOTE(review): set for every quantization method, not only "awq" —
        # presumably deliberate since ROCm has no non-Triton AWQ kernel;
        # confirm. This assigns an attribute on `envs`; `os.environ` is not
        # modified.
        envs.VLLM_USE_TRITON_AWQ = True

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the LoRA Punica wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(
        cls, device: Optional[torch.types.Device] = None
    ) -> float:
        """Return device memory in use (total minus free), in bytes.

        Peak-memory statistics for ``device`` are reset as a side effect.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # One query, so free and total come from the same snapshot.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_mx(cls) -> bool:
        """Whether device 0 is a gfx95* part."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])

    @classmethod
    def supports_fp8(cls) -> bool:
        """Whether device 0 is a gfx94*, gfx95* or gfx12* part."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Whether the platform uses the FNUZ fp8 variants (gfx94* devices)."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 e4m3 dtype variant native to this device."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Whether vLLM's custom all-reduce is enabled on this device."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ["gfx94", "gfx95"]
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Return the compute-unit count (torch's ``multi_processor_count``)."""
        return torch.cuda.get_device_properties(device_id).multi_processor_count

    @classmethod
    def is_navi(cls) -> bool:
        """Whether device 0 reports a gfx1* architecture."""
        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Create a ProcessGroup with an NCCL backend registered for ``cuda``.

        ``backend`` is not consulted; the NCCL backend is always used.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        return True

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise ValueError if bfloat16 is requested on a device whose
        capability is below 8.0 (per ``has_device_capability(80)``)."""
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True