rocm.py 25.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING, Optional
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.logger import init_logger
12
from vllm.utils.torch_utils import cuda_device_count_stateless
13
from vllm.v1.attention.backends.registry import AttentionBackendEnum
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

zhuwenwen's avatar
zhuwenwen committed
17

18
# from vllm.utils import SUPPORT_MOE_MARLIN_W16A16
19

20
21
22
# if SUPPORT_MOE_MARLIN_W16A16:
#     os.environ['VLLM_USE_MARLIN_W16A16_MOE'] = '1'
#     os.environ['MOE_NN'] = '0'
23

24
if TYPE_CHECKING:
25
    from vllm.config import VllmConfig
26
    from vllm.v1.attention.selector import AttentionSelectorConfig
27

28
29
logger = init_logger(__name__)

30
try:
31
32
33
34
35
36
37
38
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
39
40
41
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

42
43
44
45
46
47
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
zhuwenwen's avatar
zhuwenwen committed
48
49
50
51
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)
# Model architectures that cannot run on ROCm at all (currently none).
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Model architectures that ROCm only partially supports, mapped to a
# human-readable explanation of the limitation (currently none).
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# Hex PCI device-id string -> canonical device name.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    # MI300 family
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    # MI325 family
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    # MI300X HF variants
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    # Radeon
    "0x744c": "AMD_Radeon_RX7900XTX",
}

# The guard against clashing `{CUDA/HIP}_VISIBLE_DEVICES` is disabled in this
# build. Kept for reference:
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# As with NVML on CUDA, these helpers are not affected by
# `{CUDA/HIP}_VISIBLE_DEVICES`: all the related functions work on real
# physical device ids. The major benefit of using AMDSMI is that querying it
# does not initialize CUDA.


def with_amdsmi_context(fn):
    """Wrap *fn* so every call runs between ``amdsmi_init()`` and ``amdsmi_shut_down()``.

    The shutdown runs even when *fn* raises; *fn*'s return value (or
    exception) is passed through unchanged.
    """

    @wraps(fn)
    def _amdsmi_scoped(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _amdsmi_scoped


def device_id_to_physical_device_id(device_id: int) -> int:
    """Map a logical device index to the physical id in ``CUDA_VISIBLE_DEVICES``.

    ROCm uses ``CUDA_VISIBLE_DEVICES`` as its device-control variable. When the
    variable is unset, empty, or whitespace-only it is treated as unset and
    *device_id* is returned unchanged.

    Args:
        device_id: logical index into the comma-separated visible-device list.

    Returns:
        The physical device id.

    Raises:
        IndexError: *device_id* is out of range for the visible-device list.
        ValueError: the selected list entry is not an integer.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    # An empty value would otherwise reach int("") and raise ValueError.
    if not visible_devices.strip():
        return device_id
    return int(visible_devices.split(",")[device_id])


@cache
def on_gfx1x() -> bool:
    """Return True if the current device's gcnArchName contains "gfx11" or "gfx12".

    The result is computed once and cached.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx11", "gfx12"):
        if family in arch_name:
            return True
    return False

@cache
def on_mi3xx() -> bool:
    """Return True if the current device's gcnArchName contains "gfx942" or "gfx950".

    The result is computed once and cached.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx942", "gfx950"):
        if family in arch_name:
            return True
    return False


@cache
def on_gfx9() -> bool:
    """Return True if the current device's gcnArchName contains any gfx9 arch handled here.

    Recognised arch names: gfx90a, gfx942, gfx950, gfx928, gfx936, gfx938.
    The result is computed once and cached.
    """
    gfx9_archs = ("gfx90a", "gfx942", "gfx950", "gfx928", "gfx936", "gfx938")
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in gfx9_archs:
        if family in arch_name:
            return True
    return False

@cache
def get_gcn_arch_name() -> str:
    """Return the current device's gcnArchName up to (not including) the first ':'.

    The result is computed once and cached.
    """
    full_name = torch.cuda.get_device_properties("cuda").gcnArchName
    base_name, _sep, _features = full_name.partition(":")
    return base_name

@cache
def on_gfx942() -> bool:
    """Return True if the current device's gcnArchName contains "gfx942" (cached)."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx942" in arch_name


@cache
def on_gfx950() -> bool:
    """Return True if the current device's gcnArchName contains "gfx950" (cached)."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx950" in arch_name


141
@cache
142
def use_rocm_custom_paged_attention(
143
144
145
146
147
148
149
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
150
151
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
152
) -> bool:
153
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
154
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950", "gfx928", "gfx936"])
155
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
156

157
158
    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
zhuwenwen's avatar
zhuwenwen committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    # if ON_GFX9:
    #     return (
    #         (sliding_window == 0 or sliding_window == (-1, -1))
    #         and (qtype == torch.half or qtype == torch.bfloat16)
    #         and (head_size == 64 or head_size == 128)
    #         and (block_size == 16 or block_size == 32)
    #         and (gqa_ratio >= 1 and gqa_ratio <= 16)
    #         and max_seq_len <= 128 * 1024
    #         and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
    #         and sinks is None
    #     )

    # else:
    #     return (
    #         ON_GFX11_GFX12
    #         and (sliding_window == 0 or sliding_window == (-1, -1))
    #         and (qtype == torch.half or qtype == torch.bfloat16)
    #         and head_size == 128
    #         and block_size == 16
    #         and (gqa_ratio >= 3 and gqa_ratio <= 16)
    #         and max_seq_len <= 128 * 1024
    #         and alibi_slopes is None
    #         and kv_cache_dtype == "auto"
    #         and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
    #         and sinks is None
    #     )
185
    return False
186
187


@cache
def flash_attn_triton_available() -> bool:
    """Return True if the flash_attn Triton AMD backend can be used.

    Requirements, checked in order:

    * ``on_gfx1x()`` is True.
    * ``flash_attn`` and ``flash_attn.flash_attn_triton_amd`` are found by
      ``importlib.util.find_spec``.
    * ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` is exactly ``"TRUE"``. If it is
      not, a one-time hint is logged.

    Any ``ImportError`` raised while probing yields False. The result is
    cached.
    """
    if not on_gfx1x():
        return False
    try:
        from importlib.util import find_spec

        # Short-circuit: the submodule is only probed if the package exists.
        modules_found = (
            find_spec("flash_attn") is not None
            and find_spec("flash_attn.flash_attn_triton_amd") is not None
        )
        if not modules_found:
            return False
        enabled = os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE"
        if not enabled:
            logger.info_once(
                "Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
                "Flash Attention Triton backend on RDNA."
            )
        return enabled
    except ImportError:
        return False


class RocmPlatform(Platform):
    """Platform implementation for ROCm (PlatformEnum.ROCM) GPUs."""

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    # PyTorch on ROCm addresses HIP devices through the "cuda" device type.
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    # Each method appears once; "awq_marlin" used to be listed twice.
    supported_quantization: list[str] = [
        "awq",
        "awq_marlin",  # will be overwritten with awq
        "gptq",
        "gptq_marlin",  # will be overwritten with gptq
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
        "moe_wna16",
        "slimquant_w4a8",
        "w8a8_int8",
        "slimquant_w4a8_marlin",
        "slimquant_compressed_tensors_marlin",
    ]
    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]

    @classmethod
    def import_kernels(cls) -> None:
        """Import the base-class kernels, then the optional ``vllm._rocm_C`` extension.

        A missing ROCm extension (``ImportError``) is ignored.
        """
        super().import_kernels()
        try:
            import vllm._rocm_C  # noqa: F401
        except ImportError:
            pass

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the import path of the attention backend to use on ROCm.

        Decision order:

        1. Sparse attention (``attn_selector_config.use_sparse``) always uses
           ``FLASHMLA_SPARSE``.
        2. MLA (``attn_selector_config.use_mla``) uses FlashMLA when FlashMLA
           was selected (explicitly or via ``VLLM_USE_FLASH_MLA``) and the
           block size is 64. Otherwise it uses Triton MLA when Triton MLA was
           selected or no backend was selected.
        3. All other attention uses Flash Attention when
           ``VLLM_USE_FLASH_ATTN_PA`` is set and the block size is 64, and
           Triton attention otherwise.

        Args:
            selected_backend: backend explicitly requested, or ``None`` for
                automatic selection. Only consulted on the MLA path.
            attn_selector_config: selector inputs; ``block_size``,
                ``use_sparse`` and ``use_mla`` are read here.

        Returns:
            The chosen backend's ``get_path()`` string.

        Raises:
            ValueError: MLA was requested but neither FlashMLA nor Triton MLA
                can serve the selected backend.
        """
        block_size = attn_selector_config.block_size

        if attn_selector_config.use_sparse:
            # The ROCm AITER sparse-MLA backend is disabled in this build, so
            # sparse attention is always routed to FLASHMLA_SPARSE.
            logger.info_once("Using Sparse MLA backend.")
            return AttentionBackendEnum.FLASHMLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            use_flashmla = (
                selected_backend == AttentionBackendEnum.FLASHMLA
                or envs.VLLM_USE_FLASH_MLA
            )
            use_triton = (
                selected_backend == AttentionBackendEnum.TRITON_MLA
                or selected_backend is None
            )

            if use_flashmla:
                if block_size == 64:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return AttentionBackendEnum.FLASHMLA.get_path()
                # Other block sizes fall through to Triton MLA, if allowed.
                logger.warning(
                    "FlashMLA backend is not supported for block size %d"
                    " (currently only supports block size 64).",
                    block_size,
                )

            if use_triton:
                logger.info_once("Using Triton MLA backend.")
                return AttentionBackendEnum.TRITON_MLA.get_path()

            # selected_backend cannot be None here (use_triton would be True).
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        # NOTE(review): for non-MLA attention `selected_backend` is ignored.
        # This build always chooses between FLASH_ATTN and TRITON_ATTN, so
        # explicit requests such as FLEX_ATTENTION or ROCM_ATTN are not
        # honored. Confirm this is intended.
        if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
            logger.info_once(
                "Using Flash Attention backend on V1 engine. "
                "(only supports block size 64)"
            )
            return AttentionBackendEnum.FLASH_ATTN.get_path()

        # Pins VLLM_USE_FLASH_ATTN_PA to "0" process-wide. Presumably this is
        # so later readers of the flag see that paged FA is off (confirm).
        os.environ["VLLM_USE_FLASH_ATTN_PA"] = "0"
        logger.info_once("Using Triton backend on V1 engine.")
        return AttentionBackendEnum.TRITON_ATTN.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """List the attention backends accepted for ViT (vision encoder) models.

        A new list is built on every call.
        """
        supported = (
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        )
        return list(supported)

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Pick the attention backend for vision-transformer (ViT) models.

        An explicit *backend* is checked against
        ``get_supported_vit_attn_backends()`` and returned as-is. Otherwise
        the first matching rule wins:

        1. AITER Flash Attention, when AITER is enabled and ``on_gfx9()``.
        2. Flash Attention, on gfx9 with ``flash_attn`` found and *dtype*
           float16/bfloat16.
        3. Flash Attention (Triton backend), on gfx11/gfx12 when
           ``flash_attn_triton_available()`` and *dtype* float16/bfloat16.
        4. Torch SDPA.

        *head_size* is accepted for interface compatibility and not used.

        Raises:
            AssertionError: *backend* is not a supported ViT backend.
        """
        if backend is not None:
            supported = cls.get_supported_vit_attn_backends()
            assert backend in supported, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {supported}"
            )
            # Lazy %-formatting instead of an eagerly built f-string.
            logger.info_once("Using backend %s for vit attention", backend)
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        half_precision = dtype in (torch.float16, torch.bfloat16)

        if rocm_aiter_ops.is_enabled() and on_gfx9():
            logger.info_once("Using AITER Flash Attention backend for ViT model.")
            return AttentionBackendEnum.ROCM_AITER_FA

        if on_gfx9() and find_spec("flash_attn") is not None and half_precision:
            logger.info_once("Using Flash Attention backend for ViT model.")
            return AttentionBackendEnum.FLASH_ATTN

        # RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
        if on_gfx1x() and flash_attn_triton_available() and half_precision:
            logger.info_once(
                "Using Flash Attention (Triton backend) for ViT model on RDNA."
            )
            return AttentionBackendEnum.FLASH_ATTN

        logger.info_once("Using Torch SDPA backend for ViT model.")
        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make *device* the current device, via ``torch.cuda.set_device``."""
        target = device
        torch.cuda.set_device(target)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return torch's ``(major, minor)`` capability for *device_id* as a DeviceCapability.

        Cached per ``(cls, device_id)``, up to 8 entries.
        """
        capability = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=capability[0], minor=capability[1])

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Query if the set of gpus are fully connected by xgmi (1 hop).

        Args:
            physical_device_ids: indices into amdsmi's processor-handle list,
                i.e. physical ids, not CUDA_VISIBLE_DEVICES-relative ones.

        Returns:
            True if every pair of devices is linked by a single XGMI hop.
            False if any pair is not, or if amdsmi fails to report a link.
        """
        # Enumerate the processor handles once, not once per requested id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Each unordered pair is checked exactly once.
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name torch reports for logical device *device_id*.

        Cached per ``(cls, device_id)``, up to 8 entries.
        """
        # amdsmi is not used here. The ASIC-info lookup
        # (device-id -> _ROCM_DEVICE_ID_NAME_MAP / market_name) is disabled in
        # this build and torch already resolves the logical id. That avoids an
        # amdsmi init/shutdown on every call (cache hits included) and a
        # processor-handle lookup whose result was never read.
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's properties for device *device_id*."""
        return torch.cuda.get_device_properties(device_id).total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust *vllm_config* in place for the ROCm platform.

        Steps, in order:

        * Downgrade full CUDA graphs to PIECEWISE when decode- or
          prefill-context parallelism is enabled.
        * If no KV-cache block size is set, use 64 for AITER unified
          attention and 16 otherwise.
        * Default ``parallel_config.worker_cls`` to the V1 GPU worker.
        * Enable the ``rms_norm``, ``quant_fp8`` and ``grouped_topk`` custom
          ops when the matching AITER features are on and the user has not
          disabled them.
        * Enable ``sparse_attn_indexer`` unless the user already configured
          it.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Eager execution means the cudagraph *mode* is NONE. Comparing the
        # CompilationConfig object itself to CUDAGraphMode.NONE is always
        # False, which let "+rms_norm" be added even without graph capture.
        is_eager_execution = (
            compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        )
        use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            if (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
                # NOTE: This block has been deprecated
                # or get_env_variable_attn_backend()
                # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
                # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
                # to see how we can transition to the new way of selecting
                # attention backends
            ):
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        # Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

        if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
            logger.warning_once(
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
                "requires the 'grouped_topk' custom op. Overriding the "
                "user-provided '-grouped_topk'."
            )
            compilation_config.custom_ops.remove("-grouped_topk")
        # Ensure grouped_topk is always enabled when using AITER if
        # its not disabled by user
        if (
            use_aiter_fused_moe
            and "+grouped_topk" not in compilation_config.custom_ops
            and "-grouped_topk" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+grouped_topk")

        # Default dispatch to rocm's sparse_attn_indexer implementation.
        # As with grouped_topk above, an explicit user "+"/"-" entry is left
        # alone, and repeated calls do not append duplicates.
        if (
            "+sparse_attn_indexer" not in compilation_config.custom_ops
            and "-sparse_attn_indexer" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+sparse_attn_indexer")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported model architectures; warn about partially supported ones.

        Raises:
            ValueError: *model_arch* is listed in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch not in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            return

        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch],
        )

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class quantization check, then apply ROCm's AWQ handling.

        For ``quant == "awq"`` with ``VLLM_USE_TRITON_AWQ`` unset, this build
        keeps the Triton AWQ path disabled and logs a warning saying so.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            # The previous message claimed Triton AWQ was being *enabled*,
            # while the code pins it to False. The message now matches.
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set; keeping VLLM_USE_TRITON_AWQ disabled."
            )
            envs.VLLM_USE_TRITON_AWQ = False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the Punica (LoRA) wrapper; ROCm reuses the GPU one."""
        wrapper_path = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
        return wrapper_path

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Reset *device*'s peak-memory counters, then return torch's peak allocated bytes.

        Because the peak is read immediately after the reset, the value
        reflects what is allocated at that moment. The reset itself is a
        side effect visible to any other user of the peak statistics.
        """
        torch.cuda.reset_peak_memory_stats(device)
        peak_since_reset = torch.cuda.max_memory_allocated(device)
        return peak_since_reset

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator; ROCm reuses the CUDA one."""
        communicator_path = (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
        return communicator_path

    @classmethod
    def supports_mx(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx95"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in arch_name

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx94", "gfx95" or "gfx12"."""
        fp8_families = ("gfx94", "gfx95", "gfx12")
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return any(family in arch_name for family in fp8_families)

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx94" (fnuz FP8 encoding)."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return ``float8_e4m3fnuz`` when ``is_fp8_fnuz()`` is True, else ``float8_e4m3fn``."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx94" or "gfx95".

        Custom allreduce is only enabled for the MI300 series.
        """
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        for family in ("gfx94", "gfx95"):
            if family in arch_name:
                return True
        return False

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """ROCm always reports True for the opaque attention op flag."""
        return True

    @classmethod
    def is_navi(cls) -> bool:
        """Return True if device 0's gcnArchName contains "gfx1"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx1" in arch_name

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static-graph wrapper; ROCm reuses the CUDA one."""
        wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return wrapper_path

    @classmethod
    def device_count(cls) -> int:
        """Return the device count reported by ``cuda_device_count_stateless()``."""
        count = cuda_device_count_stateless()
        return count

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise if *dtype* is bfloat16 and the device capability is below 8.0.

        Raises:
            ValueError: bfloat16 requested on a device without capability
                8.0; the message names the GPU and its capability.
        """
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()
        compute_str = (
            "does not have a compute capability"
            if capability is None
            else f"has compute capability {capability.as_version_str()}"
        )
        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """ROCm reports support for the hybrid KV cache (always True)."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """ROCm reports support for static graph mode (always True)."""
        return True