rocm.py 23.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING, Optional
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.logger import init_logger
12
from vllm.utils.torch_utils import cuda_device_count_stateless
13
from vllm.v1.attention.backends.registry import AttentionBackendEnum
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

zhuwenwen's avatar
zhuwenwen committed
17

18
# from vllm.utils import SUPPORT_MOE_MARLIN_W16A16
19

20
21
22
# if SUPPORT_MOE_MARLIN_W16A16:
#     os.environ['VLLM_USE_MARLIN_W16A16_MOE'] = '1'
#     os.environ['MOE_NN'] = '0'
23

24
if TYPE_CHECKING:
25
    from vllm.config import VllmConfig
26
    from vllm.v1.attention.selector import AttentionSelectorConfig
27

28
29
logger = init_logger(__name__)

30
try:
31
32
33
34
35
36
37
38
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
39
40
41
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

42
43
44
45
46
47
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
zhuwenwen's avatar
zhuwenwen committed
48
49
50
51
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)
52

53

54
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# NOTE(review): _ROCM_SWA_REASON is an empty tuple (not a str reason) and is
# not referenced in this module; confirm whether it can be removed.
_ROCM_SWA_REASON = ()
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# GPU device ID (hex string) -> product name.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    "0x744c": "AMD_Radeon_RX7900XTX",
}
72

73
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
zhuwenwen's avatar
zhuwenwen committed
74
75
76
77
78
79
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98

# AMDSMI utils
# Note that AMDSMI, like NVML, is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`:
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Wrap ``fn`` so every call runs inside an amdsmi session.

    ``amdsmi_init()`` is called before ``fn`` and ``amdsmi_shut_down()`` is
    always called afterwards, even when ``fn`` raises. The wrapper preserves
    ``fn``'s metadata via ``functools.wraps``.
    """

    @wraps(fn)
    def _run_with_amdsmi(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _run_with_amdsmi


99
100
101
102
103
104
105
106
107
def device_id_to_physical_device_id(device_id: int) -> int:
    """Map a logical device index to a physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` is used as an index
    into its comma-separated list and the entry found there is returned as an
    int. When the variable is unset, ``device_id`` is returned unchanged.

    Raises:
        ValueError: if ``CUDA_VISIBLE_DEVICES`` is set to the empty string
            (no device is visible), or if the selected entry is not an integer.
        IndexError: if ``device_id`` is out of range for the visible list.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    device_ids = visible_devices.split(",")
    # "".split(",") == [""]; without this check the caller would only see a
    # cryptic "invalid literal for int()" error.
    if device_ids == [""]:
        raise ValueError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, so no GPU is "
            "visible to this process; cannot map logical device "
            f"{device_id} to a physical device."
        )
    return int(device_ids[device_id])


108
109
110
111
@cache
def on_gfx1x() -> bool:
    """Report whether the current GPU's gcnArchName contains gfx11 or gfx12.

    The result is memoized, so the device properties are queried only once.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for prefix in ("gfx11", "gfx12"):
        if prefix in arch_name:
            return True
    return False
112

113

114
@cache
def on_mi3xx() -> bool:
    """Return True if the current GPU's gcnArchName is gfx942 or gfx950.

    The result is memoized, so the device properties are queried only once.
    """
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in gpu_arch for arch in ("gfx942", "gfx950"))
118
119


120
@cache
def on_gfx9() -> bool:
    """Return True if the current GPU's gcnArchName matches a gfx9-family arch.

    Matched substrings: gfx90a, gfx942, gfx950, gfx928, gfx936, gfx938. The
    result is memoized, so the device properties are queried only once.
    """
    gfx9_archs = ("gfx90a", "gfx942", "gfx950", "gfx928", "gfx936", "gfx938")
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in gpu_arch for arch in gfx9_archs)
124
125


126
127
128
129
@cache
def on_gfx950() -> bool:
    """Report whether the current GPU's gcnArchName contains ``gfx950``.

    Memoized: the device properties are read on the first call only.
    """
    return "gfx950" in torch.cuda.get_device_properties("cuda").gcnArchName
130
131


132
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int | tuple[int, int],
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Return whether the ROCm custom paged-attention kernel can be used.

    The answer depends on the current GPU architecture (gcnArchName) and on
    the attention configuration passed in. It is only ever True when
    ``envs.VLLM_ROCM_CUSTOM_PAGED_ATTN`` is enabled and ``sinks`` is None.

    Results are memoized with ``functools.cache``, so every argument must be
    hashable. Tensor arguments (``alibi_slopes``, ``sinks``) are hashed by
    identity and are kept alive by the cache.

    Args:
        qtype: Query dtype; only float16 and bfloat16 are accepted.
        head_size: Attention head size.
        block_size: KV-cache block size.
        gqa_ratio: Number of query heads per KV head.
        max_seq_len: Maximum sequence length; must be <= 128 * 1024.
        sliding_window: Sliding-window setting; must be 0 or (-1, -1)
            (window disabled).
        kv_cache_dtype: KV-cache dtype string; gfx11/gfx12 require "auto".
        alibi_slopes: ALiBi slopes; gfx11/gfx12 require None.
        sinks: Attention sinks; must be None.

    Returns:
        True if the custom kernel supports this configuration on this GPU.
    """
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    # NOTE(review): unlike on_gfx9(), this list omits "gfx938"; confirm
    # whether the custom kernel is intentionally unsupported there.
    on_gfx9_arch = any(
        arch in gpu_arch
        for arch in ("gfx90a", "gfx942", "gfx950", "gfx928", "gfx936")
    )
    on_gfx11_gfx12 = any(arch in gpu_arch for arch in ("gfx11", "gfx12"))

    # The sliding window must be disabled due to an observed numerical
    # discrepancy with the custom kernel.
    if on_gfx9_arch:
        return (
            (sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )

    # Any other architecture: only gfx11/gfx12, with a narrower configuration.
    return (
        on_gfx11_gfx12
        and (sliding_window == 0 or sliding_window == (-1, -1))
        and (qtype == torch.half or qtype == torch.bfloat16)
        and head_size == 128
        and block_size == 16
        and 3 <= gqa_ratio <= 16
        and max_seq_len <= 128 * 1024
        and alibi_slopes is None
        and kv_cache_dtype == "auto"
        and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
        and sinks is None
    )
177
178


179
180
class RocmPlatform(Platform):
    """vLLM platform implementation for ROCm (HIP) GPUs.

    ROCm reuses PyTorch's ``cuda`` device type and ``CUDA`` dispatch key,
    reports ``nccl`` as its distributed backend name, and uses
    ``CUDA_VISIBLE_DEVICES`` for device control.
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq",
        "awq_marlin",  # will be overwritten with awq
        "gptq",
        "gptq_marlin",  # will be overwritten with gptq
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
        "moe_wna16",
        "slimquant_w4a8",
        "w8a8_int8",
        "slimquant_w4a8_marlin",
        "slimquant_compressed_tensors_marlin",
    ]

    # bitsandbytes not supported on gfx9 (warp size 64 limitation).
    # Evaluated once, when the class body runs at import time.
    if not on_gfx9():
        supported_quantization += ["bitsandbytes"]

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the import path of the attention backend to use.

        - Sparse attention: ROCM_AITER_MLA_SPARSE (block size 1, no fp8 KV).
        - MLA: FLASHMLA (block size 64) or TRITON_MLA.
        - Everything else: FLASH_ATTN when ``VLLM_USE_FLASH_ATTN_PA`` is set
          and the block size is 64, otherwise TRITON_ATTN.

        Raises:
            ValueError: for fp8 KV cache with sparse attention, or an
                unsupported explicit MLA backend selection.
        """
        # NOTE(review): rocm_aiter_ops is not used in this method; the import
        # is kept only in case importing vllm._aiter_ops has side effects
        # that later code relies on. Remove if it has none.
        from vllm._aiter_ops import rocm_aiter_ops  # noqa: F401

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            return cls._get_mla_backend_path(selected_backend, block_size)

        return cls._get_mha_backend_path(selected_backend, block_size)

    @classmethod
    def _get_mla_backend_path(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        block_size: int,
    ) -> str:
        """Choose between FlashMLA and Triton MLA for an MLA model."""
        use_flashmla = (
            selected_backend == AttentionBackendEnum.FLASHMLA
            or envs.VLLM_USE_FLASH_MLA
        )
        use_triton = (
            selected_backend == AttentionBackendEnum.TRITON_MLA
            or selected_backend is None
        )

        if use_flashmla:
            if block_size == 64:
                logger.info_once("Using FlashMLA backend on V1 engine.")
                return AttentionBackendEnum.FLASHMLA.get_path()
            logger.warning(
                "FlashMLA backend is not supported for block size %d"
                " (currently only supports block size 64).",
                block_size,
            )

        if use_triton:
            logger.info_once("Using Triton MLA backend.")
            return AttentionBackendEnum.TRITON_MLA.get_path()

        # Only reachable with an explicit (non-None) selection, so .name is safe.
        raise ValueError(
            f"The selected backend, {selected_backend.name}, is not supported "
            "for MLA on ROCm (supported: FLASHMLA with block size 64, "
            "TRITON_MLA)."
        )

    @classmethod
    def _get_mha_backend_path(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        block_size: int,
    ) -> str:
        """Choose the (non-MLA) attention backend.

        Only FLASH_ATTN (with VLLM_USE_FLASH_ATTN_PA and block size 64) and
        TRITON_ATTN are dispatched here. Any other explicit selection is
        overridden, and a warning says so instead of silently ignoring it.
        """
        if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
            chosen = AttentionBackendEnum.FLASH_ATTN
            logger.info_once(
                "Using Flash Attention backend on V1 engine. "
                "(only supports block size 64)"
            )
        else:
            # Force the flag off in the process environment so later readers
            # of VLLM_USE_FLASH_ATTN_PA see that the PA path is not in use.
            os.environ["VLLM_USE_FLASH_ATTN_PA"] = "0"
            chosen = AttentionBackendEnum.TRITON_ATTN
            logger.info_once("Using Triton backend on V1 engine.")

        if selected_backend is not None and selected_backend != chosen:
            logger.warning_once(
                "Attention backend %s was requested, but this ROCm build only "
                "dispatches to FLASH_ATTN (VLLM_USE_FLASH_ATTN_PA=1, block "
                "size 64) or TRITON_ATTN; using %s instead.",
                selected_backend.name,
                chosen.name,
            )
        return chosen.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Attention backends accepted for ViT (vision encoder) attention."""
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Pick the ViT attention backend.

        An explicit ``backend`` must be one of the supported ViT backends.
        Otherwise the order is: AITER FA if AITER is enabled, then
        Flash Attention on gfx9 when ``flash_attn`` is installed and ``dtype``
        is fp16/bf16, else Torch SDPA.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_enabled():
            logger.info_once("Using AITER Flash Attention backend for ViT model.")
            return AttentionBackendEnum.ROCM_AITER_FA

        if (
            on_gfx9()
            and find_spec("flash_attn") is not None
            and (dtype == torch.float16 or dtype == torch.bfloat16)
        ):
            logger.info_once("Using Flash Attention backend for ViT model.")
            return AttentionBackendEnum.FLASH_ATTN

        logger.info_once("Using Torch SDPA backend for ViT model.")
        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the (major, minor) capability PyTorch reports for the device."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop).

        ``physical_device_ids`` are physical (not CUDA_VISIBLE_DEVICES-relative)
        indices. Returns False on the first non-XGMI or multi-hop pair, or if
        amdsmi fails to report a link.
        """
        # Query the handle list once instead of once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name PyTorch reports for logical ``device_id``.

        The previous amdsmi handle lookup was discarded (its ASIC-info lookup
        was commented out), yet it forced an amdsmi init/shutdown on every
        call, cache hits included, and failed when amdsmi could not be
        imported. Only the torch query is kept.
        """
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of the device, in bytes."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for ROCm.

        - Downgrades full CUDA graphs to PIECEWISE when decode or prefill
          context parallelism is enabled.
        - Picks a default KV-cache block size (64 for AITER unified
          attention, else 16).
        - Selects the GPU worker when ``worker_cls`` is "auto".
        - Enables AITER-related custom ops unless the user disabled them.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the cudagraph mode, not the CompilationConfig object itself:
        # comparing the config object with an enum member is always False.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            if (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
                # NOTE: This block has been deprecated
                # or get_env_variable_attn_backend()
                # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
                # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
                # to see how we can transition to the new way of selecting
                # attention backends
            ):
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        # Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

        if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
            logger.warning_once(
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
                "requires the 'grouped_topk' custom op. Overriding the "
                "user-provided '-grouped_topk'."
            )
            compilation_config.custom_ops.remove("-grouped_topk")
        # Ensure grouped_topk is always enabled when using AITER if
        # its not disabled by user
        if (
            use_aiter_fused_moe
            and "+grouped_topk" not in compilation_config.custom_ops
            and "-grouped_topk" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+grouped_topk")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn about partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is listed as unsupported on ROCm.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially supported by ROCm: %s",
                model_arch,
                msg,
            )

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base check, then pin the Triton AWQ flag off for AWQ."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            # The message previously claimed to enable Triton AWQ while the
            # code forced it off; it now matches what the code does.
            # NOTE(review): the earlier os.environ["VLLM_USE_TRITON_AWQ"] = "1"
            # was commented out in this build; confirm that disabling it is
            # intended.
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set; keeping the Triton AWQ path disabled."
            )
            envs.VLLM_USE_TRITON_AWQ = False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Reset the peak stats, then return PyTorch's max allocated bytes."""
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_mx(cls) -> bool:
        """True when device 0's gcnArchName contains "gfx95"."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])

    @classmethod
    def supports_fp8(cls) -> bool:
        """True when device 0's gcnArchName contains gfx94, gfx95 or gfx12."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ["gfx94", "gfx95"]
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def is_navi(cls) -> bool:
        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ValueError if bfloat16 is requested on a device below 8.0."""
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True