rocm.py 19.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from functools import cache, lru_cache, wraps
5
from typing import TYPE_CHECKING, Dict, List, Optional
6
7
8

import torch

9
import vllm.envs as envs
10
11
from vllm.logger import init_logger

12
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
13

14
if TYPE_CHECKING:
15
    from vllm.config import ModelConfig, VllmConfig
16

17
18
logger = init_logger(__name__)

19
try:
20
21
22
    from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down, amdsmi_topo_get_link_type)
23
24
25
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

26
27
28
29
30
31
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
# NOTE: the vllm._rocm_C custom-op import is disabled in this build and kept
# here, commented out, for reference.
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# The sliding-window-attention (SWA) and Phi3V entries are disabled in this
# build; they are kept below, commented out, for reference.
# _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
#                     "Triton flash attention. For half-precision SWA support, "
#                     "please use CK flash attention by setting "
#                     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    # "Qwen2ForCausalLM":
    # _ROCM_SWA_REASON,
    # "MistralForCausalLM":
    # _ROCM_SWA_REASON,
    # "MixtralForCausalLM":
    # _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    # "Phi3VForCausalLM":
    # ("ROCm Triton flash attention may run into compilation errors due to "
    #  "excessive use of shared memory. If this happens, disable Triton FA "
    #  "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
# (this consistency check is disabled in this build; kept for reference).
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so every call runs inside an AMDSMI session.

    ``amdsmi_init()`` is called before ``fn`` and ``amdsmi_shut_down()`` is
    called afterwards from a ``finally`` clause, so the session is closed even
    when ``fn`` raises. ``fn``'s return value is passed through unchanged.
    """

    @wraps(fn)
    def _run_in_amdsmi_session(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _run_in_amdsmi_session


def device_id_to_physical_device_id(device_id: int) -> int:
    """Map a logical device index to the physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes into its
    comma-separated list and the selected entry is returned as an ``int``.
    When it is not set, ``device_id`` is returned unchanged.

    Raises:
        ValueError: if ``CUDA_VISIBLE_DEVICES`` is set but empty/blank, or the
            selected entry is not an integer.
        IndexError: if ``device_id`` is out of range for the visible list.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id
    if not visible_devices.strip():
        # An empty value hides every GPU; report that explicitly instead of
        # surfacing the opaque error from ``int("")``.
        raise ValueError(
            "CUDA_VISIBLE_DEVICES is set but empty, which hides all GPUs; "
            f"cannot map logical device {device_id} to a physical device.")
    device_ids = visible_devices.split(",")
    return int(device_ids[device_id])


def on_mi250_mi300() -> bool:
    """Return True if the current device's GCN arch name contains ``gfx90a``
    or ``gfx942`` (the MI250 / MI300 architectures)."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for target_arch in ("gfx90a", "gfx942"):
        if target_arch in arch_name:
            return True
    return False


@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                    block_size: int, gqa_ratio: int,
                                    max_seq_len: int,
                                    sliding_window: int) -> bool:
    """Report whether the ROCm custom paged-attention kernel may be used.

    In this build the kernel is unconditionally disabled: every argument is
    ignored and the result is always ``False``. The previous eligibility
    check is kept below, commented out, for reference.
    """
    # ROCm custom paged attention is not supported on gfx1* GPUs.
    # Custom paged attn used to be always supported on V0; on V1 it required
    # sliding window to be disabled due to observed numerical discrepancy.
    custom_paged_attn_enabled = False
    # return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
    #                               or sliding_window == (-1, -1))
    #         and (qtype == torch.half or qtype == torch.bfloat16)
    #         and (head_size == 64 or head_size == 128)
    #         and (block_size == 16 or block_size == 32)
    #         and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
    #         and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
    #         and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
    #                  and envs.VLLM_ROCM_USE_AITER))
    return custom_paged_attn_enabled


class RocmPlatform(Platform):
    """Platform implementation for AMD GPUs driven through ROCm/HIP.

    ROCm builds of PyTorch expose HIP devices through the ``torch.cuda``
    namespace, which is why ``device_type``/``dispatch_key`` use the CUDA
    values and the device queries below go through ``torch.cuda``.
    """
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
        "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8", "w8a8_int8"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the import path of the attention backend class to use.

        Dispatch order:
          * MLA models always get an MLA backend (FlashMLA or Triton MLA).
          * With ``VLLM_FLASH_ATTN_BACKEND`` set, FlashAttention is selected
            (or a ``ValueError`` is raised if it cannot be used).
          * Otherwise Triton attention (V1) or ROCmFlashAttention (V0).
        """
        if use_mla:
            return cls._get_mla_backend_cls(selected_backend, block_size,
                                            use_v1)
        if envs.VLLM_FLASH_ATTN_BACKEND:
            return cls._get_flash_attn_backend_cls(selected_backend,
                                                   head_size, dtype,
                                                   kv_cache_dtype, block_size,
                                                   use_v1, use_mla)
        return cls._get_rocm_attn_backend_cls(selected_backend, use_v1)

    @classmethod
    def _get_mla_backend_cls(cls, selected_backend, block_size,
                             use_v1) -> str:
        """Pick the MLA attention backend.

        FlashMLA is returned only when ``VLLM_USE_FLASH_MLA`` is set, Triton
        MLA was not explicitly selected, ``block_size`` is 64 and
        ``is_flashmla_supported()`` reports support. Every other case falls
        back to Triton MLA for the engine given by ``use_v1``, so an MLA model
        never ends up on a non-MLA backend.
        """
        wants_flash_mla = (selected_backend != _Backend.TRITON_MLA
                           and block_size == 64 and envs.VLLM_USE_FLASH_MLA)
        if wants_flash_mla:
            from vllm.attention.backends.flashmla import (
                is_flashmla_supported)
            # (is_supported, reason) -- queried once instead of twice.
            flashmla_status = is_flashmla_supported()
            if flashmla_status[0]:
                if use_v1:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashmla.FlashMLABackend")
                logger.info("Using FlashMLA backend.")
                return ("vllm.attention.backends."
                        "flashmla.FlashMLABackend")
            logger.warning("FlashMLA backend is not supported due to %s",
                           flashmla_status[1])

        if use_v1:
            logger.info_once("Using Triton MLA backend on V1 engine.")
            return ("vllm.v1.attention.backends.mla."
                    "triton_mla.TritonMLABackend")
        logger.info("Using Triton MLA backend.")
        return "vllm.attention.backends.triton_mla.TritonMLABackend"

    @classmethod
    def _get_flash_attn_backend_cls(cls, selected_backend, head_size, dtype,
                                    kv_cache_dtype, block_size, use_v1,
                                    use_mla) -> str:
        """Select a backend when ``VLLM_FLASH_ATTN_BACKEND`` is enabled.

        Raises:
            ValueError: for unsupported backend requests, and whenever
                FlashAttention cannot serve the configuration (the XFormers
                fallback used on CUDA is not available on this platform).
        """
        if use_v1:
            if selected_backend == _Backend.FLASHINFER:
                raise ValueError(
                    "FlashInfer backend on V1 engine is not supported")
            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
            if cls.has_device_capability(80):
                logger.info_once(
                    "Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "flash_attn.FlashAttentionBackend")

        if selected_backend == _Backend.FLASHINFER:
            raise ValueError("FlashInfer backend is not supported")
        elif selected_backend == _Backend.XFORMERS:
            raise ValueError("XFormers backend is not supported")
        elif selected_backend == _Backend.FLASH_ATTN:
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        if not cls.has_device_capability(80):
            # Device capability below 8.0.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            raise ValueError("XFormers backend is not supported")
        if dtype not in (torch.float16, torch.bfloat16):
            # Deliberately non-fatal in this build: log and keep going.
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
        # Checked independently of the dtype above; previously an ``elif``
        # chain let non-fp16/bf16 dtypes skip this check entirely.
        if block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            raise ValueError("XFormers backend is not supported")

        # FlashAttn is valid for the model; check that the package is
        # installed.
        try:
            import flash_attn  # noqa: F401
            from vllm.attention.backends.flash_attn import (
                FlashAttentionBackend, flash_attn_supports_fp8)
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "flash_attn package is not found. "
                "Make sure that flash_attn was built and installed "
                "(on by default).")
            raise ValueError("XFormers backend is not supported")

        supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in supported_sizes:
            logger.info(
                "Cannot use FlashAttention-2 backend for head size %d.",
                head_size)
            raise ValueError("XFormers backend is not supported")
        fp8_kv_cache = (kv_cache_dtype is not None
                        and kv_cache_dtype.startswith("fp8"))
        if fp8_kv_cache and not flash_attn_supports_fp8():
            # No FlashInfer hint here: FlashInfer is rejected above on this
            # platform, so suggesting it would only lead to another error.
            logger.info(
                "Cannot use FlashAttention backend for FP8 KV cache.")
            raise ValueError("XFormers backend is not supported")

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def _get_rocm_attn_backend_cls(cls, selected_backend, use_v1) -> str:
        """Select the default ROCm backend: Triton on V1, ROCmFlashAttention
        on V0. Uses the ``use_v1`` argument, like the other selectors."""
        if selected_backend == _Backend.FLASH_ATTN:
            selected_backend = _Backend.ROCM_FLASH
        if use_v1:
            logger.info("Using Triton Attention backend on V1 engine.")
            return ("vllm.v1.attention.backends."
                    "triton_attn.TritonAttentionBackend")
        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        elif selected_backend is not None:
            # Only report a rejected backend when one was actually requested.
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the (major, minor) capability torch reports for the
        device."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @staticmethod
    @with_amdsmi_context
    def is_fully_connected(physical_device_ids: List[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)
        """
        # Query the handle list once rather than once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Each unordered pair is checked exactly once.
            for peer_handle in handles[i + 1:]:
                try:
                    link_type = amdsmi_topo_get_link_type(
                        handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.",
                                 exc_info=error)
                    return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by torch for ``device_id``.

        This build reports the torch device name rather than the AMDSMI
        market name, so no AMDSMI session or physical-id lookup is needed.
        """
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is used unless eager mode is enforced."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in a default block size (16) and resolve an ``"auto"``
        worker class for multi-step, speculative or regular execution."""
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not yet supported on vLLM V1."
                    )
                else:
                    parallel_config.worker_cls = \
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    parallel_config.sd_worker_cls = \
                        "vllm.worker.worker.Worker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                        "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported
        ones."""
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class check; for AWQ, keep the Triton AWQ path off."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, disabling VLLM_USE_TRITON_AWQ.")
            # Normalizes the (already falsy) flag to an explicit False.
            envs.VLLM_USE_TRITON_AWQ = False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return memory allocated by PyTorch's allocator on ``device``.

        The peak counter is reset first, so ``max_memory_allocated`` reflects
        the current allocation. This build does not use the device-wide
        figure from ``torch.cuda.mem_get_info``.
        """
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 is supported on gfx94x, gfx95x and gfx12x (checks device 0)."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """FNUZ FP8 on gfx94x, OCP ``float8_e4m3fn`` elsewhere."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        # V1 support on AMD gpus is experimental
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94']
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Return the compute-unit count (torch's multi_processor_count)."""
        return torch.cuda.get_device_properties(
            device_id).multi_processor_count