rocm.py 21.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from functools import cache, lru_cache, wraps
6
from typing import TYPE_CHECKING, Optional
7
8
9

import torch

10
import vllm.envs as envs
11
from vllm.attention.backends.registry import AttentionBackendEnum
12
from vllm.logger import init_logger
13
from vllm.utils.torch_utils import cuda_device_count_stateless
14

15
from .interface import DeviceCapability, Platform, PlatformEnum
16

zhuwenwen's avatar
zhuwenwen committed
17

18
if TYPE_CHECKING:
19
    from vllm.attention.selector import AttentionSelectorConfig
20
    from vllm.config import VllmConfig
21

22
23
# Module-level logger used by the ROCm platform helpers in this file.
logger = init_logger(__name__)

24
try:
25
26
27
28
29
30
31
32
    from amdsmi import (
        AmdSmiException,
        amdsmi_get_gpu_asic_info,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
        amdsmi_topo_get_link_type,
    )
33
34
35
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

36
37
38
39
40
41
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
zhuwenwen's avatar
zhuwenwen committed
42
43
44
45
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)
46

47

48
# Model architectures that cannot run on ROCm at all (none at present).
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Model architectures that run on ROCm with caveats, mapped to the reason
# string shown to the user (none at present).
# presumably the sliding-window-attention caveat text; empty in this build.
_ROCM_SWA_REASON = ()
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}

# AMD device id (hex string) -> product name. Entries marked "VF" are the
# virtual-function ids of the same product.
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    # MI300 family
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a2": "AMD_Instinct_MI308X",
    # MI325 family
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    # MI300X HF variants
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
    # Radeon
    "0x744c": "AMD_Radeon_RX7900XTX",
}
66

67
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
zhuwenwen's avatar
zhuwenwen committed
68
69
70
71
72
73
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

# AMDSMI utils
# Note that AMDSMI (like NVML on CUDA) is not affected by
# `{CUDA/HIP}_VISIBLE_DEVICES`; all the related functions work on real
# physical device ids. The major benefit of using AMDSMI is that it does
# not initialize CUDA.


def with_amdsmi_context(fn):
    """Wrap ``fn`` so every call runs inside an AMDSMI session.

    ``amdsmi_init()`` is called before ``fn`` and ``amdsmi_shut_down()`` is
    always called afterwards, including when ``fn`` raises. The wrapper keeps
    ``fn``'s metadata via ``functools.wraps``.
    """

    @wraps(fn)
    def _call_inside_amdsmi(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
        return result

    return _call_inside_amdsmi


93
94
95
96
97
98
99
100
101
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device id.

    When ``CUDA_VISIBLE_DEVICES`` holds a non-empty comma-separated list,
    ``device_id`` indexes into that list and the chosen entry is returned as an
    ``int``. When the variable is unset, empty, or only whitespace,
    ``device_id`` is returned unchanged.

    Args:
        device_id: Logical device index (position in the visible-device list).

    Returns:
        The physical device id.

    Raises:
        IndexError: ``device_id`` is out of range for the visible-device list.
        ValueError: the selected list entry is not an integer.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    # Treat an empty value like an unset one: "".split(",") yields [""],
    # which would otherwise fail in int() with an unhelpful ValueError.
    if not visible.strip():
        return device_id
    return int(visible.split(",")[device_id])


102
103
104
105
@cache
def on_gfx1x() -> bool:
    """Return True when the current "cuda" device reports a gfx11/gfx12 arch.

    The answer is read from ``gcnArchName`` once and cached.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx11" in arch_name or "gfx12" in arch_name
106

107

108
@cache
def on_mi3xx() -> bool:
    """Return True when the current "cuda" device is gfx942 or gfx950.

    The answer is read from ``gcnArchName`` once and cached.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx942" in arch_name or "gfx950" in arch_name
112
113


114
@cache
def on_gfx9() -> bool:
    """Return True when the current "cuda" device is one of the gfx9 archs.

    Matches gfx90a, gfx942, gfx950, gfx928, gfx936 and gfx938 by substring of
    ``gcnArchName``; the answer is computed once and cached.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    gfx9_archs = ("gfx90a", "gfx942", "gfx950", "gfx928", "gfx936", "gfx938")
    return any(target in arch_name for target in gfx9_archs)
118
119


120
121
122
123
@cache
def on_gfx950() -> bool:
    """Return True when the current "cuda" device reports gfx950 (cached)."""
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx950" in arch_name
124
125


126
@cache
127
def use_rocm_custom_paged_attention(
128
129
130
131
132
133
134
    qtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    sliding_window: int,
    kv_cache_dtype: str,
135
136
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
137
) -> bool:
138
    from vllm._aiter_ops import rocm_aiter_ops
139

140
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
141
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950", "gfx928", "gfx936"])
142
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
143

144
145
    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
zhuwenwen's avatar
zhuwenwen committed
146
    # if ON_GFX9:
147
148
149
150
151
152
153
154
155
156
157
    #     return (
    #         (sliding_window == 0 or sliding_window == (-1, -1))
    #         and (qtype == torch.half or qtype == torch.bfloat16)
    #         and (head_size == 64 or head_size == 128)
    #         and (block_size == 16 or block_size == 32)
    #         and (gqa_ratio >= 1 and gqa_ratio <= 16)
    #         and max_seq_len <= 128 * 1024
    #         and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
    #         and not (rocm_aiter_ops.is_pa_attn_enabled())
    #         and sinks is None
    #     )
zhuwenwen's avatar
zhuwenwen committed
158
159

    # else:
160
161
162
163
164
165
166
167
168
169
170
171
172
    #     return (
    #         ON_GFX11_GFX12
    #         and (sliding_window == 0 or sliding_window == (-1, -1))
    #         and (qtype == torch.half or qtype == torch.bfloat16)
    #         and head_size == 128
    #         and block_size == 16
    #         and (gqa_ratio >= 3 and gqa_ratio <= 16)
    #         and max_seq_len <= 128 * 1024
    #         and alibi_slopes is None
    #         and kv_cache_dtype == "auto"
    #         and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
    #         and sinks is None
    #     )
173
    return False
174
175


176
177
class RocmPlatform(Platform):
    # Platform identity. ROCm devices are driven through torch's "cuda"
    # device type and the CUDA dispatch key.
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    # Quantization methods supported on this platform.
    supported_quantization: list[str] = [
        "awq",
        "gptq",
        "fp8",
        "compressed-tensors",
        "fbgemm_fp8",
        "gguf",
        "quark",
        "ptpc_fp8",
        "mxfp4",
        "petit_nvfp4",
        "torchao",
        "moe_wna16",
        "slimquant_w4a8",
        "w8a8_int8",
        "awq_marlin",
        "slimquant_w4a8_marlin",
        "slimquant_compressed_tensors_marlin",
    ]

    # bitsandbytes not supported on gfx9 (warp size 64 limitation)
    # NOTE(review): on_gfx9() queries torch.cuda device properties, so this
    # runs a device query when the class body executes (at import time).
    if not on_gfx9():
        supported_quantization.append("bitsandbytes")
208

209
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the import path of the attention backend to use on ROCm.

        Selection order:

        1. Sparse MLA -> ``ROCM_AITER_MLA_SPARSE``. An fp8 KV cache is
           rejected and the block size must be 1.
        2. Dense MLA -> ``FLASHMLA`` when it is requested (explicitly or via
           ``VLLM_USE_FLASH_MLA``) and ``block_size == 64``. Otherwise
           ``TRITON_MLA`` when it is requested or no backend was selected.
        3. Everything else -> ``FLASH_ATTN`` when ``VLLM_USE_FLASH_ATTN_PA``
           is set and ``block_size == 64``, otherwise ``TRITON_ATTN``. For
           this case ``selected_backend`` is not consulted.

        Args:
            selected_backend: Explicitly requested backend, or ``None``.
            attn_selector_config: Provides ``block_size``, ``kv_cache_dtype``,
                ``use_sparse`` and ``use_mla``.

        Returns:
            The backend's import path (``AttentionBackendEnum.get_path()``).

        Raises:
            ValueError: sparse MLA with an fp8 KV cache, or an MLA request
                that neither FlashMLA nor Triton MLA can serve.
            AssertionError: sparse MLA with a block size other than 1.
        """
        block_size = attn_selector_config.block_size
        kv_cache_dtype = attn_selector_config.kv_cache_dtype

        if attn_selector_config.use_sparse:
            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

        if attn_selector_config.use_mla:
            use_flashmla = (
                selected_backend == AttentionBackendEnum.FLASHMLA
                or envs.VLLM_USE_FLASH_MLA
            )
            use_triton = (
                selected_backend == AttentionBackendEnum.TRITON_MLA
                or selected_backend is None
            )

            if use_flashmla:
                if block_size == 64:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return AttentionBackendEnum.FLASHMLA.get_path()
                logger.warning(
                    "FlashMLA backend is not supported for block size %d"
                    " (currently only supports block size 64).",
                    block_size,
                )

            if use_triton:
                logger.info_once("Using Triton MLA backend.")
                return AttentionBackendEnum.TRITON_MLA.get_path()

            # use_triton is True whenever selected_backend is None, so
            # selected_backend is a real enum member from here on.
            if selected_backend == AttentionBackendEnum.FLASHMLA:
                raise ValueError(
                    f"The selected backend, {selected_backend.name}, "
                    f"only supports block size 64, but got {block_size}."
                )
            raise ValueError(
                f"The selected backend, {selected_backend.name}, "
                "is not MLA type while requested for MLA backend."
            )

        # Non-MLA: this platform picks FlashAttention (block size 64 only) when
        # VLLM_USE_FLASH_ATTN_PA is set, otherwise Triton. selected_backend is
        # intentionally not consulted on this path.
        if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
            logger.info_once(
                "Using Flash Attention backend on V1 engine. "
                "(only supports block size 64)"
            )
            return AttentionBackendEnum.FLASH_ATTN.get_path()

        # Force the flag off so later readers of VLLM_USE_FLASH_ATTN_PA agree
        # with the Triton choice. NOTE(review): presumably this also targets
        # child processes that inherit os.environ; confirm.
        os.environ["VLLM_USE_FLASH_ATTN_PA"] = "0"
        logger.info_once("Using Triton backend on V1 engine.")
        return AttentionBackendEnum.TRITON_ATTN.get_path()
310

311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """List the attention backends accepted for vision (ViT) encoders."""
        supported = [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        ]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for vision (ViT) encoders.

        An explicit ``backend`` must be in ``get_supported_vit_attn_backends()``
        (checked with ``assert``) and is returned as-is. Otherwise the order
        is: AITER FA when AITER MHA is enabled, then FlashAttention on gfx9 when
        the ``flash_attn`` package is importable, then torch SDPA.
        ``head_size`` and ``dtype`` are not used by this implementation.
        """
        if backend is None:
            from importlib.util import find_spec

            from vllm._aiter_ops import rocm_aiter_ops

            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            if rocm_aiter_ops.is_mha_enabled():
                return AttentionBackendEnum.ROCM_AITER_FA
            flash_attn_installed = find_spec("flash_attn") is not None
            if on_gfx9() and flash_attn_installed:
                return AttentionBackendEnum.FLASH_ATTN
            return AttentionBackendEnum.TORCH_SDPA

        assert backend in cls.get_supported_vit_attn_backends(), (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
        )
        logger.info_once(f"Using backend {backend} for vit attention")
        return backend

348
349
350
351
352
353
354
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current device via ``torch.cuda.set_device``."""
        torch.cuda.set_device(device=device)

355
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return torch's (major, minor) capability for ``device_id`` (cached)."""
        capability = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=capability[0], minor=capability[1])
360

361
    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop).

        Every unordered pair of the given physical devices must report a link
        with ``hops == 1`` and ``type == 2`` (XGMI).

        Args:
            physical_device_ids: Physical (not CUDA_VISIBLE_DEVICES-relative)
                device indices into AMDSMI's processor-handle list.

        Returns:
            True if every pair is a 1-hop XGMI link. False as soon as one
            pair is not, or when an AMDSMI topology query fails (the error is
            logged).
        """
        # Fetch the processor handle list once instead of once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]

        # Visit each unordered pair exactly once.
        for idx, handle in enumerate(handles):
            for peer_handle in handles[idx + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
                # type is 2 for XGMI
                if link_type["hops"] != 1 or link_type["type"] != 2:
                    return False
        return True
380

381
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name torch reports for logical device ``device_id``.

        The result is cached per ``device_id``. An AMDSMI lookup through
        ``_ROCM_DEVICE_ID_NAME_MAP`` is not used in this build, so no AMDSMI
        session is opened here. Previously a session was opened on every call
        and the handle it fetched was discarded.

        Args:
            device_id: Logical device index, as accepted by
                ``torch.cuda.get_device_name``.
        """
        return torch.cuda.get_device_name(device_id)
394

zhuwenwen's avatar
zhuwenwen committed
395
396
397
398
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's properties for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).total_memory
399
400

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply ROCm-specific adjustments to ``vllm_config`` in place.

        - Falls back from full CUDA graphs to ``PIECEWISE`` when decode or
          prefill context parallelism is enabled. Both are incompatible with
          full graphs.
        - Chooses a KV-cache block size when none is set. It uses 64 when
          ``VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION`` and ``VLLM_ROCM_USE_AITER``
          are both set, otherwise 16.
        - Resolves ``worker_cls == "auto"`` to the V1 GPU worker.
        - Appends ``+rms_norm`` when AITER RMSNorm is enabled and cudagraphs
          are not disabled.
        - Appends ``+quant_fp8`` when AITER FP8 linear is enabled.
        - Neither op is appended if the user listed its ``-`` form in
          ``custom_ops``.

        Args:
            vllm_config: The engine configuration; mutated in place.
        """
        from vllm._aiter_ops import rocm_aiter_ops
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the cudagraph *mode* with NONE. Comparing the compilation
        # config object itself with the enum is never equal.
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        # ``is_linear_fp8_enaled`` (sic) is the method name rocm_aiter_ops exposes.
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if cache_config and cache_config.block_size is None:
            if (
                envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
                # NOTE: This block has been deprecated
                # or get_env_variable_attn_backend()
                # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
                # TODO: monitor https://github.com/vllm-project/vllm/pull/30396
                # to see how we can transition to the new way of selecting
                # attention backends
            ):
                cache_config.block_size = 64
                logger.warning(
                    "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
                )
            else:
                cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        # Aiter rms norm performs best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
            and not is_eager_execution
            and "-rms_norm" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

460
461
462
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures and warn about partial support.

        Raises:
            ValueError: ``model_arch`` is listed in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture '{model_arch}' is not supported by ROCm for now."
            )

        if model_arch not in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            return

        logger.warning(
            "Model architecture '%s' is partially supported by ROCm: %s",
            model_arch,
            _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch],
        )
474

475
476
477
478
479
480
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class check, then adjust the Triton-AWQ settings.

        - For ``quant == "awq"`` with ``envs.VLLM_USE_TRITON_AWQ`` falsy, a
          warning is logged and ``envs.VLLM_USE_TRITON_AWQ`` is set to False
          on the ``envs`` module.
        - Regardless of ``quant``, ``os.environ["VLLM_USE_TRITON_AWQ"]`` is
          then set to ``"1"``.
        """
        super().verify_quantization(quant)

        awq_without_triton = quant == "awq" and not envs.VLLM_USE_TRITON_AWQ
        if awq_without_triton:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ."
            )
            # NOTE(review): the warning says "enabling", but this pins the
            # in-process flag to False. It also disagrees with the "1" written
            # to os.environ below, which child processes would inherit.
            envs.VLLM_USE_TRITON_AWQ = False

        # NOTE(review): this runs for every quantization method, not only
        # "awq". Confirm whether it was meant to sit inside the branch above.
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
485
486
487
488

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper (shared GPU one)."""
        wrapper_path = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
        return wrapper_path
489
490

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the bytes of device memory currently allocated by torch.

        The peak counter is reset first, so the "max allocated" value read
        right afterwards reflects the current allocation.
        """
        torch.cuda.reset_peak_memory_stats(device)
        allocated_now = torch.cuda.max_memory_allocated(device)
        return allocated_now
497
498
499

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator (CUDA's, reused)."""
        communicator_path = (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
        return communicator_path
503

504
505
506
507
508
    @classmethod
    def supports_mx(cls) -> bool:
        """Return True when device 0's arch name contains "gfx95"."""
        return "gfx95" in torch.cuda.get_device_properties(0).gcnArchName

509
510
511
    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True when device 0's arch name contains gfx94, gfx95 or gfx12."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        fp8_arch_tags = ("gfx94", "gfx95", "gfx12")
        return any(tag in arch_name for tag in fp8_arch_tags)
513
514
515
516

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True when device 0 is a gfx94x part (FP8 fnuz format).

        Only device 0 is checked; this assumes MI300 platforms are homogeneous.
        """
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name
518
519
520
521
522
523
524

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return ``float8_e4m3fnuz`` when ``is_fp8_fnuz()`` holds, else ``float8_e4m3fn``."""
        return torch.float8_e4m3fnuz if cls.is_fp8_fnuz() else torch.float8_e4m3fn
525

526
527
528
529
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Enable custom allreduce only on the MI300 series (gfx94/gfx95)."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch_name or "gfx95" in arch_name
532

533
534
535
536
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """ROCm always reports the attention op as opaque."""
        opaque = True
        return opaque

537
538
    @classmethod
    def is_navi(cls) -> bool:
        """Return True when device 0's arch name contains "gfx1" (Navi parts)."""
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx1" in arch_name
540
541

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static-graph (CUDA graph) wrapper."""
        wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return wrapper_path
544

545
546
547
    @classmethod
    def device_count(cls) -> int:
        """Return the visible device count via ``cuda_device_count_stateless``."""
        count = cuda_device_count_stateless()
        return count
548
549

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Reject bfloat16 on devices whose capability is below 8.0.

        Raises:
            ValueError: ``dtype`` is bfloat16 and ``has_device_capability(80)``
                is False.
        """
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()
        compute_str = (
            "does not have a compute capability"
            if capability is None
            else f"has compute capability {capability.as_version_str()}"
        )
        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
569
570
571
572

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is supported on ROCm."""
        supported = True
        return supported
573
574
575
576

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Static graph mode is supported on ROCm."""
        supported = True
        return supported