rocm.py 21.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING, Optional
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.attention.backends.registry import AttentionBackendEnum
12
from vllm.logger import init_logger
13
from vllm.utils.torch_utils import cuda_device_count_stateless
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

17
if TYPE_CHECKING:
18
    from vllm.attention.selector import AttentionSelectorConfig
19
    from vllm.config import VllmConfig
20

21
22
logger = init_logger(__name__)

# amdsmi is optional at import time: a failure is only logged, in which case
# the names below stay undefined and the amdsmi-backed helpers in this module
# cannot be used.
try:
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
except ImportError as import_err:
    logger.warning("Failed to import from amdsmi with %r", import_err)

# Load the vllm._C module; a missing module is logged rather than fatal.
try:
    import vllm._C  # noqa: F401
except ImportError as import_err:
    logger.warning("Failed to import from vllm._C with %r", import_err)

# import custom ops, trigger op registration
# NOTE: the vllm._rocm_C import below is commented out (disabled) in this build.
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)
45

46
# Models not supported by ROCm; `RocmPlatform.verify_model_arch` raises
# ValueError for any architecture listed here.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm, mapped architecture -> reason;
# `RocmPlatform.verify_model_arch` logs the reason as a warning.
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# NOTE(review): not referenced by any visible code and holds an empty tuple;
# kept so that any external import of the name keeps working.
_ROCM_SWA_REASON = ()

# amdsmi ASIC `device_id` string -> name returned by
# `RocmPlatform.get_device_name`. Ids missing from this table fall back to the
# `market_name` reported by amdsmi.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = dict(
    [
        ("0x74a0", "AMD_Instinct_MI300A"),
        ("0x74a1", "AMD_Instinct_MI300X"),
        ("0x74b5", "AMD_Instinct_MI300X"),  # MI300X VF
        ("0x74a2", "AMD_Instinct_MI308X"),
        ("0x74a5", "AMD_Instinct_MI325X"),
        ("0x74b9", "AMD_Instinct_MI325X"),  # MI325X VF
        ("0x74a9", "AMD_Instinct_MI300X_HF"),
        ("0x74bd", "AMD_Instinct_MI300X_HF"),
        ("0x744c", "AMD_Radeon_RX7900XTX"),
    ]
)
64

65
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
66
67
68
69
70
71
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

# AMDSMI utils
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using AMDSMI is that it will not initialize CUDA


def with_amdsmi_context(fn):
    """Decorate `fn` so each call runs between `amdsmi_init` and `amdsmi_shut_down`.

    `amdsmi_shut_down` runs in a `finally` clause, so it also runs when `fn`
    raises. The wrapper keeps `fn`'s metadata via `functools.wraps`.
    """

    @wraps(fn)
    def _call_inside_amdsmi(*args, **kwargs):
        amdsmi_init()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return outcome

    return _call_inside_amdsmi


91
def _gcn_arch_name() -> str:
    """Return the `gcnArchName` torch reports for the current "cuda" device."""
    return torch.cuda.get_device_properties("cuda").gcnArchName


def _gcn_arch_matches(gcn_arch: str, arch_tags: tuple[str, ...]) -> bool:
    """Return True if any entry of `arch_tags` occurs as a substring of `gcn_arch`.

    This is a substring test, so a tag such as "gfx11" matches any
    architecture name that contains it.
    """
    return any(tag in gcn_arch for tag in arch_tags)


# The predicates below are memoized with `@cache`; each queries the device
# properties at most once per process.


@cache
def on_gfx1x() -> bool:
    """Return True if the device architecture name contains gfx11 or gfx12."""
    return _gcn_arch_matches(_gcn_arch_name(), ("gfx11", "gfx12"))


@cache
def on_mi3xx() -> bool:
    """Return True if the device architecture name contains gfx942 or gfx950."""
    return _gcn_arch_matches(_gcn_arch_name(), ("gfx942", "gfx950"))


@cache
def on_gfx9() -> bool:
    """Return True if the architecture name contains gfx90a, gfx942 or gfx950."""
    return _gcn_arch_matches(_gcn_arch_name(), ("gfx90a", "gfx942", "gfx950"))


@cache
def on_gfx950() -> bool:
    """Return True if the device architecture name contains gfx950."""
    return _gcn_arch_matches(_gcn_arch_name(), ("gfx950",))


115
@cache
116
def use_rocm_custom_paged_attention(
117
118
119
120
121
122
123
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
124
125
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
126
) -> bool:
127
    from vllm._aiter_ops import rocm_aiter_ops
128

zhuwenwen's avatar
zhuwenwen committed
129
130
131
    # GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    # ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
    # ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
132

133
134
    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
zhuwenwen's avatar
zhuwenwen committed
135
    # if ON_GFX9:
136
137
138
139
140
141
142
143
144
145
146
    #     return (
    #         (sliding_window == 0 or sliding_window == (-1, -1))
    #         and (qtype == torch.half or qtype == torch.bfloat16)
    #         and (head_size == 64 or head_size == 128)
    #         and (block_size == 16 or block_size == 32)
    #         and (gqa_ratio >= 1 and gqa_ratio <= 16)
    #         and max_seq_len <= 128 * 1024
    #         and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
    #         and not (rocm_aiter_ops.is_pa_attn_enabled())
    #         and sinks is None
    #     )
zhuwenwen's avatar
zhuwenwen committed
147
148

    # else:
149
150
151
152
153
154
155
156
157
158
159
160
161
    #     return (
    #         ON_GFX11_GFX12
    #         and (sliding_window == 0 or sliding_window == (-1, -1))
    #         and (qtype == torch.half or qtype == torch.bfloat16)
    #         and head_size == 128
    #         and block_size == 16
    #         and (gqa_ratio >= 3 and gqa_ratio <= 16)
    #         and max_seq_len <= 128 * 1024
    #         and alibi_slopes is None
    #         and kv_cache_dtype == "auto"
    #         and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
    #         and sinks is None
    #     )
zhuwenwen's avatar
zhuwenwen committed
162
    return False
163
164


165
166
class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    # ROCm devices are driven through the torch.cuda API throughout this
    # platform, so the torch device type and dispatch key are the CUDA ones.
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    # Quantization methods accepted on every ROCm architecture.
    supported_quantization: list[str] = [
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
    ]
    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
    if not on_gfx9():
        supported_quantization.append("bitsandbytes")
191

192
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Resolve the attention backend implementation to use on ROCm.

        Args:
            selected_backend: Backend requested by the caller; `None` means
                "choose automatically" (the code compares it against `None`).
            attn_selector_config: Selection inputs; this method reads
                `block_size`, `kv_cache_dtype`, `use_sparse` and `use_mla`.

        Returns:
            The import path returned by `AttentionBackendEnum.<name>.get_path()`.

        Raises:
            ValueError: Sparse MLA with an fp8 KV cache, Triton MLA with block
                size 1, or a non-Triton backend requested for MLA.
            AssertionError: Sparse MLA with a block size other than 1.
            RuntimeError: An explicitly selected backend not handled on ROCm.
        """
        # Currently unused: only the disabled AITER MLA default below refers
        # to it.
        from vllm._aiter_ops import rocm_aiter_ops  # noqa: F401

        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            if selected_backend is None:
                # AITER MLA auto-selection is disabled in this build; it was:
                #   ROCM_AITER_MLA
                #   if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                #   else TRITON_MLA
                selected_backend = AttentionBackendEnum.TRITON_MLA
            if selected_backend == AttentionBackendEnum.TRITON_MLA:
                if block_size != 1:
                    logger.info_once("Using Triton MLA backend.")
                    return AttentionBackendEnum.TRITON_MLA.get_path()
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"does not support block size {block_size}."
                )
            # ROCM_AITER_MLA and ROCM_AITER_TRITON_MLA branches are disabled in
            # this build, so any other backend is rejected here.
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
            logger.info("Using FlexAttention backend.")
            return AttentionBackendEnum.FLEX_ATTENTION.get_path()

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
            logger.info("Using Rocm Attention backend on V1 engine.")
            return AttentionBackendEnum.ROCM_ATTN.get_path()

        # Explicitly selecting ROCM_AITER_FA (previously gfx9 only) or
        # ROCM_AITER_UNIFIED_ATTN is disabled in this build, so such a request
        # falls through to the RuntimeError below. The automatic path can still
        # return them when the corresponding environment variables are set.

        # Handle automatic backend selection based on environment variables
        if selected_backend is None:
            # Priority 1: Check for AITER Unified Attention (must check before MHA)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
                logger.info("Using Aiter Unified Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()

            # Priority 2: Check for AITER MHA (Flash Attention)
            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
                logger.info("Using Rocm Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_ATTN.get_path()

            # Priority 4: Check for AITER enabled without specific flags
            # This defaults to AITER FA only if MHA is not explicitly disabled
            # NOTE(review): Priority 2 already returns when
            # VLLM_ROCM_USE_AITER_MHA is truthy, so this branch is reachable
            # only when that value is falsy but not `False` (e.g. `None`).
            # Confirm whether it is still needed.
            if (
                envs.VLLM_ROCM_USE_AITER
                and on_gfx9()
                and envs.VLLM_ROCM_USE_AITER_MHA is not False
            ):
                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                return AttentionBackendEnum.ROCM_AITER_FA.get_path()

            # Default: Triton Unified Attention
            logger.info("Using Triton Attention backend on V1 engine.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        raise RuntimeError(
            f"Attention backend {selected_backend.name} is not supported on "
            "ROCm. Note that V0 attention backends have been removed."
        )
304

305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """List the backends accepted for ViT attention on ROCm, in order."""
        supported = [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        ]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for vision-transformer (ViT) layers.

        Args:
            head_size: Attention head size (not consulted by this selection).
            dtype: Activation dtype (not consulted by this selection).
            backend: Explicitly requested backend. It must be one of
                `get_supported_vit_attn_backends()`; this is checked with an
                `assert`.

        Returns:
            `backend` when one is given. Otherwise ROCM_AITER_FA when AITER MHA
            is enabled, FLASH_ATTN on gfx9 when the `flash_attn` package can
            be found, and TORCH_SDPA in all remaining cases.
        """
        if backend is not None:
            allowed = cls.get_supported_vit_attn_backends()
            assert backend in allowed, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {allowed}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        # Note: AITER FA is only supported for Qwen-VL models.
        # TODO: Add support for other VL models in their model class.
        if rocm_aiter_ops.is_mha_enabled():
            chosen = AttentionBackendEnum.ROCM_AITER_FA
        elif on_gfx9() and find_spec("flash_attn") is not None:
            chosen = AttentionBackendEnum.FLASH_ATTN
        else:
            chosen = AttentionBackendEnum.TORCH_SDPA
        return chosen

342
343
344
345
346
347
348
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Select `device` as the current device via `torch.cuda.set_device`."""
        target = device
        torch.cuda.set_device(target)

349
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the `(major, minor)` capability torch reports for `device_id`.

        Results are memoized by `lru_cache(maxsize=8)`, keyed on
        `(cls, device_id)`.
        """
        cap_major, cap_minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=cap_major, minor=cap_minor)
354

355
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Returns True only if amdsmi reports every pair of devices in
        `physical_device_ids` as a single-hop link of type 2 (XGMI). An amdsmi
        error while probing a pair is logged, and the set is then treated as
        not fully connected. An empty or single-element list returns True.
        """
        # Query the processor handle list once rather than once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Visit each unordered pair (i, j) with i < j exactly once.
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

375
    @classmethod
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a human-readable name for logical device `device_id`.

        The logical id is mapped to a physical id with
        `device_id_to_physical_device_id`, and the ASIC info is read through
        amdsmi. The ASIC `device_id` is then looked up in
        `_ROCM_DEVICE_ID_NAME_MAP`; ids not in the table fall back to the
        reported `market_name`.

        The cache wraps the amdsmi context rather than sitting inside it, so
        repeated calls for the same device return the memoized name without
        calling `amdsmi_init` / `amdsmi_shut_down` again.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        asic_info = amdsmi_get_gpu_asic_info(handle)
        asic_device_id: str = asic_info["device_id"]
        if asic_device_id in _ROCM_DEVICE_ID_NAME_MAP:
            return _ROCM_DEVICE_ID_NAME_MAP[asic_device_id]
        return asic_info["market_name"]
386
387
388
389
390

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return `total_memory` from torch's properties for `device_id`."""
        return torch.cuda.get_device_properties(device_id).total_memory
391
392

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust `vllm_config` in place for ROCm.

        * Downgrades full CUDA graphs to PIECEWISE when decode or prefill
          context parallelism is enabled.
        * Chooses a default KV-cache block size when none is set: 64 when
          AITER unified attention is enabled, otherwise 16.
        * Uses the GPU worker class when `worker_cls` is `"auto"`.
        * Enables the `rms_norm` custom op (only when CUDA graphs are in use)
          and the `quant_fp8` custom op when the corresponding AITER kernels
          are enabled and the op was not explicitly disabled with a `-` entry.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the cudagraph *mode* against NONE; the CompilationConfig
        # object itself never equals an enum member.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        # NOTE(review): "enaled" spelling kept as-is; confirm it matches the
        # method name defined on rocm_aiter_ops.
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            if (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
                # NOTE: This block has been deprecated
                # or get_env_variable_attn_backend()
                # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
                # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
                # to see how we can transition to the new way of selecting
                # attention backends
            ):
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        # Aiter rms norm perform best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

452
453
454
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures and warn about partial support.

        Raises:
            ValueError: If `model_arch` is in `_ROCM_UNSUPPORTED_MODELS`.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch not in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            return
        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch],
        )
466

467
468
469
470
471
472
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Validate `quant` and force the Triton AWQ path for AWQ on ROCm.

        After the base-class check, AWQ quantization sets the
        `VLLM_USE_TRITON_AWQ` environment variable to `"1"`. A warning is
        logged first if the variable was not already enabled. Other
        quantization methods leave the environment untouched.
        """
        super().verify_quantization(quant)
        if quant != "awq":
            return
        if not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
        # Scoped to AWQ: previously this ran for every quantization method,
        # which did not match the AWQ-specific warning above.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
476
477
478
479

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the Punica (LoRA) wrapper used on ROCm."""
        punica_path = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
        return punica_path
480
481

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the bytes currently allocated by torch's caching allocator.

        The peak statistics are reset first, so the `max_memory_allocated`
        reading that follows reflects the allocation level at call time rather
        than an earlier high-water mark.
        """
        # Disabled alternative (device-wide usage):
        # torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0]
        torch.cuda.reset_peak_memory_stats(device)
        peak_bytes = torch.cuda.max_memory_allocated(device)
        return peak_bytes
488
489
490

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator (the CUDA one)."""
        communicator_path = (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
        return communicator_path
494

495
496
497
498
499
    @classmethod
    def supports_mx(cls) -> bool:
        """Return True when device 0's architecture name contains "gfx95"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx95" in arch_name

500
501
502
    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's arch name contains gfx94, gfx95 or gfx12."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        for family in ("gfx94", "gfx95", "gfx12"):
            if family in arch_name:
                return True
        return False
504
505
506
507

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0's architecture name contains "gfx94".

        Only device 0 is inspected; this assumes MI300 platforms are
        homogeneous.
        """
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name
509
510
511
512
513
514
515

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return float8_e4m3fnuz when `is_fp8_fnuz()` is true, else float8_e4m3fn."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn
516

517
518
519
520
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Enable custom allreduce only for MI300-series (gfx94 / gfx95) devices.

        Only device 0's architecture name is checked.
        """
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name or "gfx95" in arch_name
523

524
525
526
527
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always True on ROCm."""
        opaque = True
        return opaque

528
529
    @classmethod
    def is_navi(cls) -> bool:
        """Return True if device 0's architecture name contains "gfx1"."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx1" in arch_name
531
532

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static (CUDA) graph wrapper class."""
        wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return wrapper_path
535

536
537
538
    @classmethod
    def device_count(cls) -> int:
        """Return the device count from `cuda_device_count_stateless()`."""
        count = cuda_device_count_stateless()
        return count
539
540

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise `ValueError` if `dtype` cannot be used on the current GPU.

        Only `torch.bfloat16` is checked. It requires
        `has_device_capability(80)` to be true; other dtypes are accepted.
        """
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()
        compute_str = (
            "does not have a compute capability"
            if capability is None
            else f"has compute capability {capability.as_version_str()}"
        )
        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
560
561
562
563

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always True on ROCm."""
        supported = True
        return supported
564
565
566
567

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always True on ROCm."""
        supported = True
        return supported