rocm.py 18.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from datetime import timedelta
6
from functools import cache, lru_cache, wraps
7
from typing import TYPE_CHECKING, Optional
8
9

import torch
10
11
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
12

13
import vllm.envs as envs
14
from vllm.logger import init_logger
15
from vllm.utils import cuda_device_count_stateless
16

17
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
18

19
if TYPE_CHECKING:
20
    from vllm.config import ModelConfig, VllmConfig
21

22
23
logger = init_logger(__name__)

# amdsmi is optional at import time: when it is missing, only a warning is
# logged and the amdsmi names below are left unbound in this module.
try:
    from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down, amdsmi_topo_get_link_type)
except ImportError as import_err:
    logger.warning("Failed to import from amdsmi with %r", import_err)

# Compiled vLLM extension; a failed import is logged, not raised.
try:
    import vllm._C  # noqa: F401
except ImportError as import_err:
    logger.warning("Failed to import from vllm._C with %r", import_err)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as import_err:
    logger.warning("Failed to import from vllm._rocm_C with %r", import_err)

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
    # Architectures that share the sliding-window-attention caveat.
    **dict.fromkeys(
        ("Qwen2ForCausalLM", "MistralForCausalLM", "MixtralForCausalLM"),
        _ROCM_SWA_REASON),
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"),
}

# AMDSMI ASIC "device_id" -> device name reported by get_device_name().
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
}
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`: when
# HIP_VISIBLE_DEVICES is set, mirror it into CUDA_VISIBLE_DEVICES (the device
# control variable this platform uses), or require that an already-set,
# non-empty CUDA_VISIBLE_DEVICES holds the same value.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Raise explicitly instead of `assert`, which is stripped under -O
        # and would otherwise let conflicting device lists through silently.
        if val != cuda_val:
            raise ValueError(
                f"HIP_VISIBLE_DEVICES={val!r} conflicts with "
                f"CUDA_VISIBLE_DEVICES={cuda_val!r}; set only one of them "
                "or give both the same value.")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so each call runs inside an AMDSMI session.

    ``amdsmi_init()`` is called before ``fn``. ``amdsmi_shut_down()`` is
    always called afterwards, even when ``fn`` raises, and ``fn``'s return
    value is passed through unchanged.
    """

    @wraps(fn)
    def _amdsmi_scoped(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _amdsmi_scoped


103
104
105
106
107
108
@cache
def on_gfx1x() -> bool:
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])


@cache
def on_mi3xx() -> bool:
    """Return True if the current device's ``gcnArchName`` contains gfx942 or
    gfx950. The result is cached for the process lifetime."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx942" in arch_name or "gfx950" in arch_name


@cache
def on_gfx9() -> bool:
    """Return True if the current device's ``gcnArchName`` contains gfx90a,
    gfx942 or gfx950. The result is cached for the process lifetime."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx90a", "gfx942", "gfx950"):
        if family in arch_name:
            return True
    return False


@cache
def use_rocm_custom_paged_attention(
        qtype: torch.dtype,
        head_size: int,
        block_size: int,
        gqa_ratio: int,
        max_seq_len: int,
        sliding_window: int,
        kv_cache_dtype: str,
        alibi_slopes: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None) -> bool:
    """Decide whether the ROCm custom paged-attention kernel can be used.

    Results are memoized with ``functools.cache``, so every argument must be
    hashable.

    Args:
        qtype: Query dtype; only ``torch.half`` and ``torch.bfloat16`` qualify.
        head_size: Head size; 64 or 128 on gfx9, only 128 on gfx11/gfx12.
        block_size: KV-cache block size; 16 or 32 on gfx9, only 16 on
            gfx11/gfx12.
        gqa_ratio: GQA ratio; must be in [1, 16] on gfx9 and [3, 16] on
            gfx11/gfx12.
        max_seq_len: Must not exceed ``128 * 1024``.
        sliding_window: Compared against both ``0`` and ``(-1, -1)``; either
            value counts as "sliding window disabled" on V1.
        kv_cache_dtype: Must be ``"auto"`` on gfx11/gfx12 (unchecked on gfx9).
        alibi_slopes: Must be ``None`` on gfx11/gfx12 (unchecked on gfx9).
        sinks: Must be ``None`` on every architecture.

    Returns:
        True only when the current device is gfx9 or gfx11/gfx12, every
        constraint above holds, and ``VLLM_ROCM_CUSTOM_PAGED_ATTN`` is set.
        On gfx9, it is also False when both AITER and AITER paged attention
        are enabled.
    """
    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    sliding_window_ok = (not envs.VLLM_USE_V1 or sliding_window == 0
                         or sliding_window == (-1, -1))
    dtype_ok = qtype in (torch.half, torch.bfloat16)

    # Reuse the cached arch helpers instead of re-deriving the arch lists.
    if on_gfx9():
        return (sliding_window_ok and dtype_ok
                and head_size in (64, 128)
                and block_size in (16, 32)
                and 1 <= gqa_ratio <= 16
                and max_seq_len <= 128 * 1024
                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
                and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
                         and envs.VLLM_ROCM_USE_AITER)
                and sinks is None)

    return (on_gfx1x() and sliding_window_ok and dtype_ok
            and head_size == 128 and block_size == 16
            and 3 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024 and alibi_slopes is None
            and kv_cache_dtype == "auto"
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
class RocmPlatform(Platform):
    """vLLM platform definition for AMD GPUs running ROCm.

    ROCm devices are driven through PyTorch's CUDA API (``device_type`` is
    ``"cuda"``), so most device queries here go through ``torch.cuda``.
    Topology and device-name queries go through AMDSMI.
    """
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
        "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao"
    ]

    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:
        """Return the attention backend for vision-encoder attention.

        ``head_size`` and ``dtype`` are accepted for interface compatibility
        but do not influence the choice here.
        """
        if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                and on_gfx9()):
            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            return _Backend.ROCM_AITER_FA
        if on_gfx9():
            return _Backend.FLASH_ATTN
        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
        """Return the import path of the attention backend class to use.

        For MLA, the backend is ``selected_backend`` or, when that is None, a
        default derived from AITER availability and ``block_size``. For
        non-MLA attention on V1, the choice is driven by environment flags.

        Raises:
            RuntimeError: if MLA is requested without V1, or if V1 is not
                enabled for non-MLA attention (V0 backends were removed).
            ValueError: if the chosen MLA backend does not support
                ``block_size`` or is not an MLA backend.
        """
        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them.")

            from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
                is_aiter_mla_enabled)

            if selected_backend is None:
                # Default to AITER MLA when it is enabled or block_size == 1,
                # otherwise fall back to Triton MLA.
                selected_backend = (_Backend.ROCM_AITER_MLA if
                                    is_aiter_mla_enabled() or block_size == 1
                                    else _Backend.TRITON_MLA)

            if selected_backend == _Backend.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "triton_mla.TritonMLABackend")
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}.")
            if selected_backend in (_Backend.ROCM_AITER_MLA,
                                    _Backend.ROCM_AITER_MLA_VLLM_V1):
                if block_size == 1:
                    logger.info("Using AITER MLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size} "
                    "(currently only supports block size 1).")
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend.")

        if envs.VLLM_USE_V1:
            if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                    and on_gfx9()):
                logger.info("Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "rocm_aiter_fa.AiterFlashAttentionBackend")
            elif ((envs.VLLM_ROCM_USE_AITER
                   and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
                  or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
                  or selected_backend == _Backend.ROCM_ATTN_VLLM_V1):
                # rocm specific backend, with aiter and/or
                #   triton prefix-prefill
                logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "rocm_attn.RocmAttentionBackend")
            else:
                # default case, using triton unified attention
                logger.info("Using Triton Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend.")

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the ``(major, minor)`` capability that ``torch.cuda``
        reports for ``device_id``, memoized per device id."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        ``physical_device_ids`` index into AMDSMI's processor-handle list.
        Returns False, after logging, if any AMDSMI topology query fails.
        """
        # Query the handle list once instead of once per requested id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Each unordered pair is checked exactly once.
            for peer_handle in handles[i + 1:]:
                try:
                    link_type = amdsmi_topo_get_link_type(
                        handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.",
                                 exc_info=error)
                    return False
        return True

    @classmethod
    @with_amdsmi_context
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for logical device ``device_id``.

        The logical id is translated to a physical id before querying AMDSMI.
        ASIC device ids listed in ``_ROCM_DEVICE_ID_NAME_MAP`` are mapped
        through it; any other device returns AMDSMI's ``market_name``. The
        AMDSMI context decorator is outside the cache, so AMDSMI is
        initialized and shut down even on cache hits.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from ``torch.cuda`` device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply ROCm-specific defaults to ``vllm_config`` in place.

        - Sets ``cache_config.block_size`` to 16 when it is unset.
        - Resolves ``parallel_config.worker_cls == "auto"`` to a concrete
          worker class.
        - Enables the ``+rms_norm`` custom op when AITER RMSNorm is requested
          on V1 and the CUDA graph mode is not ``NONE``.

        Raises:
            NotImplementedError: if speculative decoding is requested on V0.
        """
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the configured CUDA graph mode, not the config object
        # itself: a CompilationConfig never equals a CUDAGraphMode member, so
        # the previous check was always False.
        is_eager_execution = (
            compilation_config.cudagraph_mode == CUDAGraphMode.NONE)

        use_v1 = envs.VLLM_USE_V1
        use_aiter_rms_norm = (envs.VLLM_ROCM_USE_AITER
                              and envs.VLLM_ROCM_USE_AITER_RMSNORM)

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            if vllm_config.speculative_config:
                if not use_v1:
                    raise NotImplementedError(
                        "Speculative decoding is not supported on vLLM V0.")
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            elif use_v1:
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if use_v1 and use_aiter_rms_norm and not is_eager_execution:
            compilation_config.custom_ops.append("+rms_norm")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Validate ``quant`` through the base class, then force Triton AWQ.

        A warning is logged when AWQ is used while ``VLLM_USE_TRITON_AWQ`` is
        not already set.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
        # NOTE(review): this runs for every quantization method, not only
        # "awq". It is kept unconditional in case other paths rely on it;
        # confirm whether it should move under the `if` above.
        envs.VLLM_USE_TRITON_AWQ = True

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Reset peak-memory stats, then return device bytes in use.

        Usage is computed as total minus free memory from
        ``torch.cuda.mem_get_info``.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # One snapshot, so free and total describe the same moment.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """Return True if device 0's arch name contains ``gfx95``."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's arch is gfx94*, gfx95* or gfx12*."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0 is gfx94* (uses the FNUZ fp8 variant)."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 dtype for this device: e4m3fnuz or e4m3fn."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True if device 0 is gfx94* or gfx95*."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94', 'gfx95']
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always True on ROCm."""
        return True

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Return ``multi_processor_count`` for ``device_id``."""
        return torch.cuda.get_device_properties(
            device_id).multi_processor_count

    @classmethod
    def is_navi(cls) -> bool:
        """Return True if device 0's arch name contains ``gfx1``."""
        return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static (CUDA) graph wrapper."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ProcessGroup backed by an NCCL backend bound to "cuda".

        ``backend`` is not consulted; NCCL is always used.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the visible device count via cuda_device_count_stateless."""
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Always True on ROCm; no KV-cache dtype is rejected here."""
        return True

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise if bfloat16 is requested on a device below capability 8.0.

        Raises:
            ValueError: for ``torch.bfloat16`` when
                ``has_device_capability(80)`` is False.
        """
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always True on ROCm."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always True on ROCm."""
        return True