rocm.py 22.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING, Optional
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.logger import init_logger
12
from vllm.utils.torch_utils import cuda_device_count_stateless
13
from vllm.v1.attention.backends.registry import AttentionBackendEnum
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

17
if TYPE_CHECKING:
18
    from vllm.config import VllmConfig
19
    from vllm.v1.attention.selector import AttentionSelectorConfig
20

21
22
logger = init_logger(__name__)

23
try:
24
25
26
27
28
29
30
31
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
32
33
34
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

35
36
37
38
39
40
41
42
43
44
45
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

46
# Architectures ROCm cannot run at all; RocmPlatform.verify_model_arch raises
# for any name listed here.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Architectures ROCm only partially supports, mapping
# architecture name -> reason (logged as a warning by verify_model_arch).
# NOTE(review): _ROCM_SWA_REASON is not referenced in the visible code.
_ROCM_SWA_REASON = ()
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# AMDSMI ASIC "device_id" -> canonical device name returned by
# RocmPlatform.get_device_name; ids not listed fall back to "market_name".
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    "0x744c": "AMD_Radeon_RX7900XTX",
}
64

65
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`. When
# HIP_VISIBLE_DEVICES is set, CUDA_VISIBLE_DEVICES must either be unset/empty
# (it is then copied from HIP_VISIBLE_DEVICES) or hold exactly the same value.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Raise explicitly instead of `assert`: asserts are stripped under -O,
        # which would let conflicting device lists through silently.
        if val != cuda_val:
            raise ValueError(
                f"HIP_VISIBLE_DEVICES={val!r} conflicts with "
                f"CUDA_VISIBLE_DEVICES={cuda_val!r}; set only one of them "
                "or give both the same value."
            )
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`:
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Run every call of *fn* inside an AMDSMI session.

    ``amdsmi_init()`` is called before *fn* and ``amdsmi_shut_down()`` after
    it, in a ``finally`` clause so shutdown also happens when *fn* raises.
    The return value of *fn* is passed through unchanged.
    """

    @wraps(fn)
    def _call_with_amdsmi(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _call_with_amdsmi


91
92
93
94
95
96
@cache
def on_gfx1x() -> bool:
    """Return True if the current device's arch name contains gfx11 or gfx12.

    The result is memoized for the process lifetime by ``@cache``.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx11" in arch_name or "gfx12" in arch_name


97
@cache
def on_mi3xx() -> bool:
    """Return True if the current device's arch name contains gfx942 or gfx950.

    The result is memoized for the process lifetime by ``@cache``.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(tag in arch_name for tag in ("gfx942", "gfx950"))


@cache
def on_gfx9() -> bool:
    """Return True on gfx90a, gfx942 or gfx950 devices (memoized by ``@cache``)."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    gfx9_tags = ("gfx90a", "gfx942", "gfx950")
    return any(tag in arch_name for tag in gfx9_tags)
107
108


109
110
111
112
113
114
@cache
def on_gfx950() -> bool:
    """Return True if the current device's arch name contains gfx950 (memoized)."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx950" in arch_name


115
@cache
def use_rocm_custom_paged_attention(
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int | tuple[int, int],
    kv_cache_dtype: str,
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    """Decide whether the ROCm custom paged-attention kernel can be used.

    Args:
        qtype: Query dtype; only ``torch.half`` and ``torch.bfloat16`` qualify.
        head_size: Attention head size (64/128 on gfx9, 128 on gfx11/gfx12).
        block_size: KV-cache block size (16/32 on gfx9, 16 on gfx11/gfx12).
        gqa_ratio: Query-to-KV head ratio (1..16 on gfx9, 3..16 on gfx11/12).
        max_seq_len: Maximum sequence length; must not exceed 128 * 1024.
        sliding_window: ``0`` or ``(-1, -1)`` means "no sliding window"; any
            other value disqualifies the kernel.
        kv_cache_dtype: Only checked on gfx11/gfx12, where it must be "auto".
        alibi_slopes: Only checked on gfx11/gfx12, where it must be None.
        sinks: Must be None on every architecture.

    Returns:
        True only on gfx9 (gfx90a/gfx942/gfx950) or gfx11/gfx12 devices when
        all constraints for that architecture hold and
        ``VLLM_ROCM_CUSTOM_PAGED_ATTN`` is enabled; False otherwise.
    """
    # NOTE(review): @cache keys on every argument, including the tensor
    # arguments, so each distinct alibi_slopes/sinks tensor passed here stays
    # referenced by the cache even though only its None-ness is used.

    # Sliding window must be disabled due to an observed numerical
    # discrepancy in the custom kernel.
    if on_gfx9():
        # kv_cache_dtype and alibi_slopes are not restricted on gfx9.
        return (
            (sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
            and sinks is None
        )

    return (
        on_gfx1x()
        and (sliding_window == 0 or sliding_window == (-1, -1))
        and (qtype == torch.half or qtype == torch.bfloat16)
        and head_size == 128
        and block_size == 16
        and 3 <= gqa_ratio <= 16
        and max_seq_len <= 128 * 1024
        and alibi_slopes is None
        and kv_cache_dtype == "auto"
        and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
        and sinks is None
    )
159
160


161
162
class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    # The torch device type, dispatch key and distributed backend used on
    # ROCm are the CUDA-named ones (this file drives devices via torch.cuda).
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq",
        "awq_marlin",  # will be overwritten with awq
        "gptq",
        "gptq_marlin",  # will be overwritten with gptq
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
    ]
    # bitsandbytes not supported on gfx9 (warp size 64 limitation).
    # NOTE(review): on_gfx9() runs here at class-definition time, so importing
    # this module queries torch.cuda device properties.
    if not on_gfx9():
        supported_quantization.append("bitsandbytes")
189

190
191
192
193
194
195
196
197
198
199
200
    @classmethod
    def import_kernels(cls) -> None:
        """Import ROCm-specific kernels.

        Runs the base-class kernel imports, then tries to import the
        ``vllm._rocm_C`` extension; an ImportError is silently ignored.
        """
        super().import_kernels()

        try:
            import vllm._rocm_C  # noqa: F401
        except ImportError:
            pass

201
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the import path of the attention backend to use on ROCm.

        Selection order:

        1. Sparse attention always uses the AITER sparse MLA backend (fp8 KV
           cache is rejected, block size must be 1).
        2. MLA uses ``selected_backend``, or when it is None: AITER MLA if
           AITER MLA is enabled or ``block_size == 1``, else Triton MLA.
        3. An explicitly selected non-MLA backend is honoured (AITER FA only
           on gfx9).
        4. With no selection, environment variables choose between AITER
           unified attention, AITER FA, ROCm attention and Triton attention.

        Raises:
            ValueError: if the selected backend cannot serve the request.
            RuntimeError: if the selected backend is not supported on ROCm.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            if selected_backend is None:
                selected_backend = (
                    AttentionBackendEnum.ROCM_AITER_MLA
                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                    else AttentionBackendEnum.TRITON_MLA
                )
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                logger.info("Using AITER MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
                logger.info("Using AITER TRITON MLA backend.")
                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()

            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
            if on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
            else:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    "is only supported on gfx9 architectures."
                )

        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            logger.info("Using Aiter Unified Attention backend.")
            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
                logger.info("Using Rocm Attention backend.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            # NOTE(review): a truthy VLLM_ROCM_USE_AITER_MHA was already handled
            # by Priority 2, so this only fires for a falsy value other than
            # False (e.g. None) -- confirm this is intended.
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )
312

313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends accepted for ViT attention on ROCm."""
        supported = [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        ]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for ViT (vision encoder) attention.

        An explicit ``backend`` is validated against
        ``get_supported_vit_attn_backends()`` and returned as-is. Otherwise
        AITER FA is used when AITER is enabled, then Flash Attention on gfx9
        when ``flash_attn`` is importable and ``dtype`` is fp16/bf16, falling
        back to Torch SDPA. ``head_size`` is not consulted here.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_enabled():
            logger.info_once("Using AITER Flash Attention backend for ViT model.")
            return AttentionBackendEnum.ROCM_AITER_FA

        flash_attn_usable = (
            on_gfx9()
            and find_spec("flash_attn") is not None
            and dtype in (torch.float16, torch.bfloat16)
        )
        if flash_attn_usable:
            logger.info_once("Using Flash Attention backend for ViT model.")
            return AttentionBackendEnum.FLASH_ATTN

        logger.info_once("Using Torch SDPA backend for ViT model.")
        return AttentionBackendEnum.TORCH_SDPA

355
356
357
358
359
360
361
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current device via ``torch.cuda.set_device``."""
        torch.cuda.set_device(device)

362
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return torch's (major, minor) capability for ``device_id``.

        Results are memoized per ``(cls, device_id)`` by ``lru_cache``.
        """
        cap_major, cap_minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=cap_major, minor=cap_minor)
367

368
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Every unordered pair of the given physical devices must report an
        XGMI link (link type 2) with exactly one hop. If AMDSMI raises while
        querying a pair, the error is logged and False is returned.
        """
        import itertools

        # Fetch the processor handle list once rather than once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        # combinations() yields each (i, j) pair with i < j exactly once, in
        # the same order as nested enumerate loops filtered on i < j.
        for handle, peer_handle in itertools.combinations(handles, 2):
            try:
                link_type = amdsmi_topo_get_link_type(handle, peer_handle)
            except AmdSmiException as error:
                logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                return False
            # type is 2 for XGMI
            if link_type["hops"] != 1 or link_type["type"] != 2:
                return False
        return True

388
    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for logical device ``device_id``.

        The logical id is mapped to a physical id and its AMDSMI ASIC info is
        read. A known ``device_id`` is translated through
        ``_ROCM_DEVICE_ID_NAME_MAP``; otherwise the ASIC ``market_name`` is
        returned.

        ``lru_cache`` wraps ``with_amdsmi_context`` (not the other way round)
        so cache hits do not re-initialise and shut down AMDSMI.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        device_name: str = asic_info["device_id"]
        if device_name in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[device_name]
        return asic_info["market_name"]
399
400
401
402
403

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's properties for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory
404
405

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for ROCm.

        - Downgrades full CUDA graphs to PIECEWISE when decode or prefill
          context parallelism is enabled.
        - Picks a default KV-cache block size when none is set (64 for AITER
          unified attention, otherwise 16).
        - Resolves ``worker_cls == "auto"`` to the GPU worker.
        - Enables the ``rms_norm``, ``quant_fp8`` and ``grouped_topk`` custom
          ops when the corresponding AITER features are enabled, unless the
          user disabled them with a ``-<op>`` entry.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the cudagraph mode, not the CompilationConfig object itself:
        # a config object never equals an enum member, so the old comparison
        # was always False and eager mode was never detected.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            if (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
                # NOTE: This block has been deprecated
                # or get_env_variable_attn_backend()
                # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
                # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
                # to see how we can transition to the new way of selecting
                # attention backends
            ):
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        # Aiter rms norm performs best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

        if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
            logger.warning_once(
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
                "requires the 'grouped_topk' custom op. Overriding the "
                "user-provided '-grouped_topk'."
            )
            compilation_config.custom_ops.remove("-grouped_topk")
        # Ensure grouped_topk is always enabled when using AITER if
        # it is not disabled by the user
        if (
            use_aiter_fused_moe
            and "+grouped_topk" not in compilation_config.custom_ops
            and "-grouped_topk" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+grouped_topk")

483
484
485
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures and warn about partial support.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch not in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            return

        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch],
        )
497

498
499
500
501
502
503
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Validate ``quant`` and switch AWQ to the Triton kernels.

        Delegates validation to the base class and warns when AWQ is
        requested without ``VLLM_USE_TRITON_AWQ``. It then sets
        ``VLLM_USE_TRITON_AWQ=1`` in the process environment.
        """
        super().verify_quantization(quant)
        awq_without_triton = quant == "awq" and not envs.VLLM_USE_TRITON_AWQ
        if awq_without_triton:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # NOTE(review): this runs for every quantization method, not only
        # "awq" -- confirm whether that is intended.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
507
508
509
510

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the Punica (LoRA) wrapper class."""
        wrapper_path = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
        return wrapper_path
511
512

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the bytes currently in use on ``device`` (total minus free).

        Peak memory statistics for ``device`` are reset first.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # One snapshot: two separate mem_get_info() calls could observe
        # different free-memory values and cost a second driver query.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes
518
519
520

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator (shared with CUDA)."""
        return (
            "vllm.distributed.device_communicators."
            "cuda_communicator.CudaCommunicator"
        )
524

525
526
527
528
529
    @classmethod
    def supports_mx(cls) -> bool:
        """Return True when device 0's arch name contains ``gfx95``."""
        return "gfx95" in torch.cuda.get_device_properties(0).gcnArchName

530
531
532
    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True when device 0 is a gfx94x, gfx95x or gfx12 part."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        fp8_tags = ("gfx94", "gfx95", "gfx12")
        return any(tag in arch_name for tag in fp8_tags)
534
535
536
537

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True when device 0 is a gfx94x part (uses FP8 FNUZ formats)."""
        # Only device 0 is inspected; this assumes MI300 platforms are
        # homogeneous.
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name
539
540
541
542
543
544
545

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return ``float8_e4m3fnuz`` on FNUZ devices, else ``float8_e4m3fn``."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn
546

547
548
549
550
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Enable custom allreduce only on gfx94x/gfx95x (MI300-series) devices.

        Only device 0's arch name is inspected.
        """
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name or "gfx95" in arch_name
553

554
555
556
557
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """ROCm always reports True for the opaque attention op."""
        opaque = True
        return opaque

558
559
    @classmethod
    def is_navi(cls) -> bool:
        """Return True when device 0's arch name contains ``gfx1`` (gfx1xxx)."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx1" in arch_name
561
562

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static (CUDA) graph wrapper class."""
        wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return wrapper_path
565

566
567
568
    @classmethod
    def device_count(cls) -> int:
        """Return the visible device count via ``cuda_device_count_stateless``."""
        count = cuda_device_count_stateless()
        return count
569

570
    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise if ``dtype`` is bfloat16 and device capability is below 8.0.

        Raises:
            ValueError: bfloat16 was requested on a device without capability
                of at least 8.0 (or with no reported capability).
        """
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()
        compute_str = (
            "does not have a compute capability"
            if capability is None
            else f"has compute capability {capability.as_version_str()}"
        )
        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
590
591
592
593

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """ROCm reports support for the hybrid KV cache."""
        supported = True
        return supported
594
595
596
597

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """ROCm reports support for static graph mode."""
        supported = True
        return supported