# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from datetime import timedelta
from functools import cache, lru_cache, wraps
from itertools import combinations
from typing import TYPE_CHECKING, Optional

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig

logger = init_logger(__name__)

try:
    from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down, amdsmi_topo_get_link_type)
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Models not supported by ROCm.
# Architectures listed here are rejected with a ValueError by
# `RocmPlatform.verify_model_arch`.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# The reason string is logged as a warning by
# `RocmPlatform.verify_model_arch`; it does not block the model.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
    _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
# ASIC `device_id` (as reported by amdsmi) -> canonical device name.
# `RocmPlatform.get_device_name` consults this table first and falls back to
# the amdsmi `market_name` for ids that are not listed.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`.
# When HIP_VISIBLE_DEVICES is set, a non-empty CUDA_VISIBLE_DEVICES must carry
# the same value; if CUDA_VISIBLE_DEVICES is unset or empty, the HIP value is
# copied into it.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Still an AssertionError, but now says which variables disagree.
        assert val == cuda_val, (
            f"HIP_VISIBLE_DEVICES ({val!r}) and CUDA_VISIBLE_DEVICES "
            f"({cuda_val!r}) are both set but differ")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so every call runs between amdsmi init and shutdown.

    ``amdsmi_init()`` is called before ``fn``; ``amdsmi_shut_down()`` is
    called afterwards whether ``fn`` returns or raises. The wrapper keeps
    the metadata of ``fn`` via ``functools.wraps``.
    """

    @wraps(fn)
    def _amdsmi_scoped(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Runs on both the normal and the exceptional path.
            amdsmi_shut_down()
        return result

    return _amdsmi_scoped


@cache
def on_gfx1x() -> bool:
    """Return True if the "cuda" device's gcnArchName contains gfx11 or gfx12.

    The result is memoized, so the device properties are queried only once
    per process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx11", "gfx12"):
        if family in arch_name:
            return True
    return False


@cache
def on_mi3xx() -> bool:
    """Return True if the "cuda" device's gcnArchName contains gfx942 or gfx950.

    The result is memoized, so the device properties are queried only once
    per process.
    """
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in GPU_ARCH for arch in ("gfx942", "gfx950"))


@cache
def on_gfx9() -> bool:
    """Return True if the "cuda" device's gcnArchName contains one of
    gfx90a, gfx942 or gfx950.

    The result is memoized, so the device properties are queried only once
    per process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return ("gfx90a" in arch_name or "gfx942" in arch_name
            or "gfx950" in arch_name)


@cache
def use_rocm_custom_paged_attention(
        qtype: torch.dtype,
        head_size: int,
        block_size: int,
        gqa_ratio: int,
        max_seq_len: int,
        sliding_window: int,
        kv_cache_dtype: str,
        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
    """Decide whether the ROCm custom paged-attention kernel may be used.

    The decision depends on the current device's ``gcnArchName`` and on the
    ``envs`` flags read below. Results are memoized per argument tuple.

    Args:
        qtype: Query dtype; must be ``torch.half`` or ``torch.bfloat16``.
        head_size: Attention head size (64 or 128 on gfx9; 128 otherwise).
        block_size: KV-cache block size (16 or 32 on gfx9; 16 otherwise).
        gqa_ratio: Grouped-query ratio (1..16 on gfx9; 3..16 otherwise).
        max_seq_len: Maximum sequence length; must not exceed 128 * 1024.
        sliding_window: Sliding-window setting; on V1 it must be ``0`` or
            ``(-1, -1)``.
        kv_cache_dtype: KV-cache dtype string; only checked (must be
            ``"auto"``) on the gfx11/gfx12 path.
        alibi_slopes: ALiBi slopes; only checked (must be ``None``) on the
            gfx11/gfx12 path.

    Returns:
        True when every condition for the detected architecture holds.
    """
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])

    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    if ON_GFX9:
        # bool(...) keeps the annotated return type even if an env flag is
        # not a strict bool.
        return bool((not envs.VLLM_USE_V1 or sliding_window == 0
                     or sliding_window == (-1, -1))
                    and (qtype == torch.half or qtype == torch.bfloat16)
                    and (head_size == 64 or head_size == 128)
                    and (block_size == 16 or block_size == 32)
                    and (gqa_ratio >= 1 and gqa_ratio <= 16)
                    and max_seq_len <= 128 * 1024
                    and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
                    and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
                             and envs.VLLM_ROCM_USE_AITER))

    else:
        return bool(ON_GFX11_GFX12
                    and (not envs.VLLM_USE_V1 or sliding_window == 0
                         or sliding_window == (-1, -1))
                    and (qtype == torch.half or qtype == torch.bfloat16)
                    and head_size == 128 and block_size == 16
                    and (gqa_ratio >= 3 and gqa_ratio <= 16)
                    and max_seq_len <= 128 * 1024 and alibi_slopes is None
                    and kv_cache_dtype == "auto"
                    and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)


class RocmPlatform(Platform):
    """`Platform` implementation for ROCm devices.

    ROCm devices are driven through the torch "cuda" device type, so most
    queries below go through `torch.cuda`. Topology and device-name queries
    use amdsmi instead.
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
        "quark", "ptpc_fp8"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the dotted path of the attention backend class to use.

        ``head_size``, ``dtype`` and ``kv_cache_dtype`` are accepted for
        interface compatibility but are not consulted here.

        For MLA (``use_mla`` true) the choice is between the Triton MLA and
        AITER MLA backends, keyed on ``selected_backend``, ``block_size`` and
        ``use_v1``. For non-MLA, the V1 decision reads ``envs.VLLM_USE_V1``
        (not the ``use_v1`` argument); otherwise the ROCm flash-attention
        backend is returned.

        Raises:
            ValueError: If an MLA backend is requested with an unsupported
                block size, or a non-MLA backend is selected while
                ``use_mla`` is true.
        """
        if use_mla:
            from vllm.attention.backends.rocm_aiter_mla import (
                is_aiter_mla_enabled)

            if selected_backend is None:
                selected_backend = (_Backend.ROCM_AITER_MLA if
                                    is_aiter_mla_enabled() or block_size == 1
                                    else _Backend.TRITON_MLA)

            if selected_backend == _Backend.TRITON_MLA:
                if block_size == 1:
                    raise ValueError(
                        f"The selected backend, {selected_backend.name}, "
                        f"does not support block size {block_size}.")
                if use_v1:
                    logger.info_once(
                        "Using Triton MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "triton_mla.TritonMLABackend")
                logger.info("Using Triton MLA backend.")
                return "vllm.attention.backends.triton_mla.TritonMLABackend"  # noqa: E501

            if selected_backend == _Backend.ROCM_AITER_MLA \
                or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
                if block_size != 1:
                    raise ValueError(
                        f"The selected backend, {selected_backend.name}, "
                        f"does not support block size {block_size}. "
                        "(currently only supports block size 1)")
                if use_v1:
                    logger.info("Using AITER MLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                logger.info("Using AITER MLA backend")
                return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend"  # noqa: E501

            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                f"is not MLA type while requested for MLA backend.")

        if selected_backend is None or selected_backend == _Backend.FLASH_ATTN:
            selected_backend = _Backend.ROCM_FLASH

        if envs.VLLM_USE_V1:
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
                and on_gfx9():
                logger.info("Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "rocm_aiter_fa.AiterFlashAttentionBackend")
            logger.info("Using Triton Attention backend on V1 engine.")
            return ("vllm.v1.attention.backends."
                    "triton_attn.TritonAttentionBackend")

        # V0 path: every selection ends up on the ROCm flash-attention
        # backend; the branches below only decide what gets logged.
        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the (major, minor) capability torch reports for
        ``device_id``, wrapped in a `DeviceCapability`. Cached (8 entries).
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Returns False as soon as one pair is not a 1-hop XGMI link, or if
        amdsmi raises while querying a pair (the error is logged).
        """
        # Fetch the handle table once instead of once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        # Each unordered pair is checked once, in the same order as the
        # original nested `i < j` loops.
        for handle, peer_handle in combinations(handles, 2):
            try:
                link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                # type is 2 for XGMI
                if link_type["hops"] != 1 or link_type["type"] != 2:
                    return False
            except AmdSmiException as error:
                logger.error("AMD 1 hop XGMI detection failed.",
                             exc_info=error)
                return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name for logical ``device_id``.

        The ASIC ``device_id`` reported by amdsmi is looked up in
        ``_ROCM_DEVICE_ID_NAME_MAP``; unknown ids fall back to the amdsmi
        ``market_name``. The cache sits outside the amdsmi context so a
        cache hit does not run amdsmi init/shutdown.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return False (with a warning) when ``enforce_eager`` is truthy,
        True otherwise.
        """
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in ROCm defaults on ``vllm_config`` in place.

        Sets the cache block size to 16 when unset, and resolves
        ``parallel_config.worker_cls == "auto"`` to a concrete worker path
        based on multi-step scheduling, speculative decoding and
        ``envs.VLLM_USE_V1``.

        Raises:
            NotImplementedError: If multi-step scheduling or speculative
                decoding is requested while ``envs.VLLM_USE_V1`` is set.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not yet supported on vLLM V1."
                    )
                else:
                    parallel_config.worker_cls = \
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    parallel_config.sd_worker_cls = \
                        "vllm.worker.worker.Worker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                            "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures and warn about partial support.

        Raises:
            ValueError: If ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class check, then enable ``VLLM_USE_TRITON_AWQ``."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
        # NOTE(review): the flag is set for every quantization method, not
        # only AWQ; kept as-is - confirm this is intended.
        envs.VLLM_USE_TRITON_AWQ = True

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Reset peak memory stats and return used bytes (total - free) as
        reported by ``torch.cuda.mem_get_info``.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # One query yields both numbers, so they come from the same snapshot.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """Return True if device 0's gcnArchName contains gfx95."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's gcnArchName contains gfx94, gfx95 or
        gfx12.
        """
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0's gcnArchName contains gfx94."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return ``float8_e4m3fnuz`` when `is_fp8_fnuz` holds, otherwise
        ``float8_e4m3fn``.
        """
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        """Always return True; ``model_config`` is not inspected."""
        # V1 support on AMD gpus is experimental
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True if device 0's gcnArchName contains gfx94 or gfx95."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94', 'gfx95']
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Return ``multi_processor_count`` from torch's device properties."""
        return torch.cuda.get_device_properties(
            device_id).multi_processor_count

    @classmethod
    def is_navi(cls) -> bool:
        """Return True if device 0's gcnArchName contains gfx1."""
        return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        """Return the dotted path of the piecewise compilation backend."""
        return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a `ProcessGroup` with an NCCL backend registered for "cuda".

        ``backend`` is not consulted; the NCCL backend is always used.
        ``timeout`` is applied through the NCCL backend options.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the device count from `cuda_device_count_stateless`."""
        return cuda_device_count_stateless()