# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from datetime import timedelta
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Optional

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import SUPPORT_TC, cuda_device_count_stateless

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

# When SUPPORT_TC is falsy, force-disable the V1 engine, the flash-attention
# paged-attention path and FlashMLA by writing the controlling variables into
# os.environ. NOTE(review): this presumably relies on `vllm.envs` reading
# os.environ lazily on each attribute access — confirm in vllm/envs.py.
if not SUPPORT_TC:
    os.environ['VLLM_USE_V1'] = '0'
    os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
    os.environ['VLLM_USE_FLASH_MLA'] = '0'

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig

logger = init_logger(__name__)

# amdsmi is optional at import time. If the import fails only a warning is
# logged, and any later call to an amdsmi_* helper will raise NameError.
try:
    from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down, amdsmi_topo_get_link_type)
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# The ROCm-specific custom-op extension is not imported in this build:
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)


# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# The upstream entries are disabled (commented out) in this build, so the
# mapping is empty.
# _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
#                     "Triton flash attention. For half-precision SWA support, "
#                     "please use CK flash attention by setting "
#                     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
    # "Qwen2ForCausalLM":
    # _ROCM_SWA_REASON,
    # "MistralForCausalLM":
    # _ROCM_SWA_REASON,
    # "MixtralForCausalLM":
    # _ROCM_SWA_REASON,
    # "PaliGemmaForConditionalGeneration":
    # ("ROCm flash attention does not yet "
    #  "fully support 32-bit precision on PaliGemma"),
    # "Phi3VForCausalLM":
    # ("ROCm Triton flash attention may run into compilation errors due to "
    #  "excessive use of shared memory. If this happens, disable Triton FA "
    #  "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}

# Hex device-id string -> product name. NOTE(review): keys are presumably PCI
# device ids; this map is not referenced by the visible code in this module.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES` (disabled in this build;
# ROCm relies on CUDA_VISIBLE_DEVICES only).
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val


# AMDSMI utils
# Like NVML, AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`: the
# related functions work on real physical device ids. The major benefit of
# using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Wrap *fn* so every call runs inside an AMDSMI session.

    ``amdsmi_init()`` is called before *fn*, and ``amdsmi_shut_down()`` is
    called afterwards from a ``finally`` clause, so the session is closed even
    when *fn* raises. The wrapped function's return value is passed through.
    """

    @wraps(fn)
    def _inside_amdsmi_session(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _inside_amdsmi_session


def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device id.

    If ``CUDA_VISIBLE_DEVICES`` is set, its comma-separated entries are the
    physical ids and *device_id* indexes into them. If it is unset,
    *device_id* is returned unchanged.

    Raises:
        ValueError: ``CUDA_VISIBLE_DEVICES`` is set but blank (no devices are
            visible), or the selected entry is not an integer.
        IndexError: *device_id* is outside the listed entries.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return device_id
    # A blank value hides every device; report that clearly instead of
    # failing later on int('').
    if not visible.strip():
        raise ValueError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, so no devices "
            f"are visible; cannot map device id {device_id}.")
    device_ids = visible.split(",")
    return int(device_ids[device_id])


@cache
def on_gfx1x() -> bool:
    """Report whether the current device's gcnArchName mentions gfx11/gfx12.

    Cached for the lifetime of the process.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx11" in arch_name or "gfx12" in arch_name


@cache
def on_mi3xx() -> bool:
    """Return True if the device arch name contains gfx942 or gfx950.

    These are the targets this helper groups as "MI3xx". The result is cached
    for the lifetime of the process.
    """
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in gpu_arch for arch in ("gfx942", "gfx950"))


@cache
def on_gfx9() -> bool:
    """Return True if the device arch name matches one of the gfx9 targets.

    Matched substrings: gfx90a, gfx942, gfx950, gfx928 and gfx936. The
    result is cached for the lifetime of the process.
    """
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in gpu_arch
               for arch in ("gfx90a", "gfx942", "gfx950", "gfx928", "gfx936"))


@cache
def on_gfx950() -> bool:
    """Report whether the current device's gcnArchName mentions gfx950.

    Cached for the lifetime of the process.
    """
    return "gfx950" in torch.cuda.get_device_properties("cuda").gcnArchName


@cache
def use_rocm_custom_paged_attention(
        qtype: torch.dtype,
        head_size: int,
        block_size: int,
        gqa_ratio: int,
        max_seq_len: int,
        sliding_window: int,
        kv_cache_dtype: str,
        alibi_slopes: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None) -> bool:
    """Decide whether the ROCm custom paged-attention kernel may be used.

    In this build the kernel is disabled unconditionally: the function always
    returns ``False`` and ignores its arguments. The signature is kept so
    existing callers need no change. The previous implementation queried the
    device architecture before returning ``False`` and never used the result;
    that query is no longer made.

    The upstream eligibility rules are kept below as a reference comment.
    """
    # Upstream rules (disabled). Custom paged attn is always supported on V0;
    # on V1 it requires sliding window disabled due to an observed numerical
    # discrepancy.
    # if on_gfx9():
    #     return ((not envs.VLLM_USE_V1 or sliding_window == 0
    #              or sliding_window == (-1, -1))
    #             and (qtype == torch.half or qtype == torch.bfloat16)
    #             and (head_size == 64 or head_size == 128)
    #             and (block_size == 16 or block_size == 32)
    #             and (gqa_ratio >= 1 and gqa_ratio <= 16)
    #             and max_seq_len <= 128 * 1024
    #             and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
    #             and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
    #                      and envs.VLLM_ROCM_USE_AITER) and sinks is None)
    # else:
    #     return (on_gfx1x() and (not envs.VLLM_USE_V1 or sliding_window == 0
    #                             or sliding_window == (-1, -1))
    #             and (qtype == torch.half or qtype == torch.bfloat16)
    #             and head_size == 128 and block_size == 16
    #             and (gqa_ratio >= 3 and gqa_ratio <= 16)
    #             and max_seq_len <= 128 * 1024 and alibi_slopes is None
    #             and kv_cache_dtype == "auto"
    #             and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
    return False


class RocmPlatform(Platform):
    """Platform hooks for ROCm devices in this vLLM build.

    ROCm devices are driven through torch's CUDA-compatible API
    (``device_type = "cuda"``) and are selected via ``CUDA_VISIBLE_DEVICES``.
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    # Quantization method names accepted on this platform.
    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
        "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao",
        "moe_wna16", "slimquant_w4a8", "w8a8_int8", "awq_marlin",
        "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin"
    ]

    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:
        """Pick the attention backend for vision (ViT) encoders.

        Returns ``FLASH_ATTN`` when ``on_gfx9()`` is true, otherwise
        ``TORCH_SDPA``. ``head_size`` and ``dtype`` are not consulted.
        """
        # The upstream AITER FA path is disabled in this build:
        # if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
        #         and on_gfx9()):
        #     # Note: AITER FA is only supported for Qwen-VL models.
        #     # TODO: Add support for other VL models in their model class.
        #     return _Backend.ROCM_AITER_FA
        if on_gfx9():
            return _Backend.FLASH_ATTN
        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink, use_sparse) -> str:
        """Return the fully-qualified class name of the attention backend.

        MLA models (``use_mla``) require ``use_v1``. They are routed to the
        sparse FlashMLA backend when ``use_sparse`` is set; otherwise to
        FlashMLA (block size 64 only) or Triton MLA. Other models get flash
        attention when ``VLLM_USE_FLASH_ATTN_PA`` is set and the block size is
        64, and the Triton backend otherwise. ``head_size``, ``dtype``,
        ``kv_cache_dtype`` and ``has_sink`` are not consulted.

        Raises:
            RuntimeError: MLA requested without V1, or the V1 engine is
                disabled.
            ValueError: MLA requested but no MLA backend matches the
                selection.
        """
        # Upstream rejects sparse attention on ROCm; this build serves it from
        # the MLA branch below instead:
        # if use_sparse:
        #     raise NotImplementedError(
        #         "Sparse Attention is not supported on ROCm.")
        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them.")

            from vllm.attention.ops.flashmla import is_flashmla_supported

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla.flashmla_sparse."
                        "FlashMLASparseBackend")

            use_flashmla = (selected_backend == _Backend.FLASHMLA
                            or envs.VLLM_USE_FLASH_MLA
                            or (selected_backend is None
                                and is_flashmla_supported()[0]))
            use_triton = (selected_backend == _Backend.TRITON_MLA
                          or selected_backend is None)

            if use_flashmla:
                if block_size != 64:
                    logger.warning(
                        "FlashMLA backend is not supported for block size %d"
                        " (currently only supports block size 64).",
                        block_size)
                else:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashmla.FlashMLABackend")

            if use_triton:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")

            # Do not fall through to a non-MLA backend for an MLA model.
            raise ValueError(
                f"Attention backend {selected_backend} is not supported for "
                f"MLA on ROCm with block size {block_size}.")

        if envs.VLLM_USE_V1:
            TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501

            if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
                logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
                return FLASH_ATTN_V1

            # NOTE(review): clears the flag in os.environ, presumably so that
            # later readers of VLLM_USE_FLASH_ATTN_PA see the Triton choice.
            os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
            logger.info_once("Using Triton backend on V1 engine.")
            return TRITON_ATTN

        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend.")

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the (major, minor) capability torch reports for a device."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Every pair of the given physical devices must report a single-hop
        link of type 2 (XGMI). An AMDSMI error for any pair is logged and
        treated as "not fully connected".
        """
        # Fetch the handle list once instead of once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1:]:
                try:
                    link_type = amdsmi_topo_get_link_type(
                        handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.",
                                 exc_info=error)
                    return False
        return True

    @classmethod
    @with_amdsmi_context
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name as reported by torch for *device_id*.

        The upstream AMDSMI market-name lookup is disabled in this build.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        # NOTE(review): the handle is unused now that the market-name lookup
        # is disabled; the lookup is kept so an out-of-range physical id still
        # raises as before.
        handle = amdsmi_get_processor_handles()[physical_device_id]  # noqa: F841
        # return amdsmi_get_gpu_asic_info(handle)["market_name"]
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply ROCm-specific defaults to *vllm_config* in place.

        - Defaults ``cache_config.block_size`` to 16 when it is unset.
        - Resolves ``parallel_config.worker_cls == "auto"`` to the V1 or V0
          GPU worker; speculative decoding requires V1.
        - Adds the ``+rms_norm`` custom op when AITER RMSNorm is enabled, V1
          is in use, CUDA graphs are not disabled, and ``-rms_norm`` was not
          requested.
        """
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the configured CUDA-graph mode; comparing the config object
        # itself against an enum member would presumably never be equal.
        is_eager_execution = (compilation_config.cudagraph_mode ==
                              CUDAGraphMode.NONE)

        use_v1 = envs.VLLM_USE_V1
        use_aiter_rms_norm = (envs.VLLM_ROCM_USE_AITER
                              and envs.VLLM_ROCM_USE_AITER_RMSNORM)

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            if vllm_config.speculative_config:
                if not use_v1:
                    raise NotImplementedError(
                        "Speculative decoding is not supported on vLLM V0.")
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            elif use_v1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.gpu_worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (use_v1 and use_aiter_rms_norm and not is_eager_execution
                and "-rms_norm" not in compilation_config.custom_ops):
            compilation_config.custom_ops.append("+rms_norm")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: *model_arch* is in the unsupported list.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base check, then handle the AWQ/Triton flag.

        For ``awq`` with ``VLLM_USE_TRITON_AWQ`` unset, a warning is logged
        and ``envs.VLLM_USE_TRITON_AWQ`` is set to False.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, disabling VLLM_USE_TRITON_AWQ.")
            envs.VLLM_USE_TRITON_AWQ = False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the qualified name of the LoRA Punica wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return torch's allocated-memory figure for *device*.

        Peak statistics are reset first, so ``max_memory_allocated`` reports
        the allocation level at the time of the call rather than an earlier
        peak.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # Upstream used the device-wide total/free difference instead:
        # return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
        #     device)[0]
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the qualified name of the device communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """True if device 0's arch name contains ``gfx95``."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in gcn_arch

    @classmethod
    def supports_fp8(cls) -> bool:
        """True if device 0's arch name contains gfx94, gfx95 or gfx12."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """True if device 0's arch name contains ``gfx94``."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the FNUZ fp8 dtype on gfx94 devices, else ``float8_e4m3fn``."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """True if device 0's arch name contains gfx94 or gfx95."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94', 'gfx95']
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Return ``multi_processor_count`` from torch's device properties."""
        return torch.cuda.get_device_properties(
            device_id).multi_processor_count

    @classmethod
    def is_navi(cls) -> bool:
        """True if device 0's arch name contains ``gfx1``."""
        return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the qualified name of the static-graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ProcessGroup backed by ProcessGroupNCCL on ``cuda``.

        The group is created directly from *prefix_store* without going
        through ``torch.distributed.init_process_group``. *timeout* is applied
        to the NCCL backend options; *backend* is not consulted.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the visible device count without initializing CUDA state."""
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Accept every KV-cache dtype on this platform."""
        return True

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise if bfloat16 is requested on a device below capability 8.0.

        Raises:
            ValueError: *torch_dtype* is bfloat16 and
                ``has_device_capability(80)`` is false.
        """
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True