"tests/vscode:/vscode.git/clone" did not exist on "582bbe6bd708d01d74d6d02d6ef59b4c3c34a7b1"
cuda.py 29.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from datetime import timedelta
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
11

12
import torch
13
14
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
15
from typing_extensions import ParamSpec
16

17
18
# import custom ops, trigger op registration
import vllm._C  # noqa
19
import vllm.envs as envs
20
from vllm.logger import init_logger
21
from vllm.utils import cuda_device_count_stateless, import_pynvml
22

23
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
24

25
if TYPE_CHECKING:
26
    from vllm.config import ModelConfig, VllmConfig
27

28
29
logger = init_logger(__name__)

# Type variables that let `with_nvml_context` preserve the wrapped signature.
_P, _R = ParamSpec("_P"), TypeVar("_R")

# NVML bindings, obtained through vLLM's import helper.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

39

40
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap ``fn`` so every call runs between ``nvmlInit`` and ``nvmlShutdown``.

    ``nvmlShutdown`` runs in a ``finally`` clause, so it is also executed
    when ``fn`` raises. If ``nvmlInit`` itself raises, ``fn`` is not called
    and no shutdown is attempted.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
        return outcome

    return _nvml_scoped


53
54
class CudaPlatformBase(Platform):
    """CUDA platform logic shared by the NVML and non-NVML variants.

    The device queries (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory``, ``is_fully_connected``) raise
    ``NotImplementedError`` in this base class and are supplied by
    subclasses.
    """

    # Platform identity.
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"

    # Integration points: Ray resources, torch.distributed backend and the
    # environment variable that controls device visibility.
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Floating-point dtypes usable on the current GPU generation."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            dtypes = [torch.bfloat16, torch.float16, torch.float32]
        elif self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            dtypes = [torch.float16, torch.float32]
        else:
            # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
            # though vLLM doesn't support these GPUs.
            dtypes = [torch.float32]
        return dtypes

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current CUDA device.

        A one-element tensor is allocated on ``device`` right after
        ``torch.cuda.set_device`` to force the device to be set eagerly;
        see https://github.com/pytorch/pytorch/issues/155668 for why and
        when this is needed.
        """
        torch.cuda.set_device(device)
        # Eager-set trick; the tensor itself is discarded.
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id``.

        Abstract hook: concrete CUDA platforms override this.
        """
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id``; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id``; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Report whether ``device_ids`` are all NVLink-connected.

        Abstract hook: concrete CUDA platforms override this.
        """
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Log platform-specific environment warnings; a no-op by default."""

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply CUDA-specific defaults to ``vllm_config`` in place.

        - Resolves ``parallel_config.worker_cls == "auto"`` to the V1 or V0
          GPU worker.
        - Defaults an unset KV-cache ``block_size`` to 16.
        - For MLA models, forces the block size required by the MLA backend
          (chosen from ``VLLM_ATTENTION_BACKEND`` or a per-GPU default).
        - Disables CUDA graphs when data parallelism is combined with the
          DeepEP high-throughput all2all backend.

        Raises:
            NotImplementedError: speculative decoding requested on vLLM V0.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            wants_spec_decode = bool(vllm_config.speculative_config)
            use_v1 = envs.VLLM_USE_V1
            if wants_spec_decode and not use_v1:
                raise NotImplementedError(
                    "Speculative decoding is not supported on vLLM V0.")
            parallel_config.worker_cls = ("vllm.v1.worker.gpu_worker.Worker"
                                          if use_v1 else
                                          "vllm.worker.worker.Worker")

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            # Without `VLLM_ATTENTION_BACKEND`, MLA defaults to CutlassMLA on
            # Blackwell and FlashMLA elsewhere; either way the block size the
            # chosen backend requires is forced below.
            # NOTE(review): cache_config is dereferenced unguarded here even
            # though the default-block-size check above allows it to be falsy.
            requested = envs.VLLM_ATTENTION_BACKEND
            if requested is None:
                on_blackwell = cls.is_device_capability(100)
                use_cutlass_mla = on_blackwell
                use_flashmla = not on_blackwell
                use_flashinfer_mla = False
                if on_blackwell:
                    # TODO: This does not work, because the
                    # global_force_attn_backend_context_manager is not set.
                    # See vllm/attention/selector.py:_cached_get_attn_backend
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
            else:
                use_flashmla = requested == "FLASHMLA"
                use_cutlass_mla = requested == "CUTLASS_MLA"
                use_flashinfer_mla = requested == "FLASHINFER_MLA"

            from vllm.attention.ops.flashmla import is_flashmla_supported

            if (use_flashmla and is_flashmla_supported()[0]
                    and cache_config.block_size != 64):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size != 128:
                cache_config.block_size = 128
                logger.info("Forcing kv cache block size to 128 for "
                            "CUTLASS_MLA backend.")

            if (use_flashinfer_mla
                    and cache_config.block_size not in (32, 64)):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA "
                    "backend.")

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        dp_with_deepep_ht = (
            envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
            and parallel_config.data_parallel_size > 1)
        if (dp_with_deepep_ht
                and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
            # TODO: Piecewise Cuda graph might be enabled
            # if torch compile cache key issue fixed
            # See https://github.com/vllm-project/vllm/pull/25093
            logger.info(
                "Data Parallel: disabling cudagraphs since DP "
                "with DeepEP high-throughput kernels are not CUDA Graph "
                "compatible. The DeepEP low-latency kernels are CUDA Graph "
                "compatible. Set the all_to_all backend to deepep_low_latency "
                "to use those kernels instead.")
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the bytes currently allocated by tensors on ``device``.

        The caching allocator is emptied and the peak-memory statistics are
        reset before the peak is read back. The result therefore reflects the
        allocation at this moment rather than a historical maximum.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        peak_now = torch.cuda.max_memory_allocated(device)
        return peak_now

    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:
        """Pick the attention backend for ViT attention.

        FlashAttention is returned only for fp16/bf16 on SM 8.0+ GPUs, and
        only when the V1 FlashAttention backend accepts ``head_size`` and
        ``dtype``. Every other case falls back to xFormers.
        """
        if dtype not in (torch.float16, torch.bfloat16):
            return _Backend.XFORMERS
        if not cls.has_device_capability(80):
            # Fallback for Volta/Turing GPUs or FA not supported
            return _Backend.XFORMERS

        from vllm.attention.selector import is_attn_backend_supported
        fa_v1_path = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
        fa_usable = is_attn_backend_supported(fa_v1_path,
                                              head_size,
                                              dtype,
                                              allow_import_error=False)
        return _Backend.FLASH_ATTN if fa_usable else _Backend.XFORMERS

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
        """Return the dotted import path of the attention backend to use.

        Selection proceeds in three stages:
        1. MLA backends (when ``use_mla``).
        2. V1 backends (when ``use_v1``).
        3. V0 backends.

        An explicit ``selected_backend`` wins where it applies. Otherwise a
        default is derived from the device capability, ``head_size``,
        ``dtype``, ``block_size``, ``kv_cache_dtype`` and ``has_sink``.

        NOTE(review): if ``use_mla`` is set but no MLA backend ends up usable
        (e.g. FlashMLA requested with ``block_size != 64``), control falls
        through to the non-MLA stages. Confirm this is intended.

        Raises:
            ValueError: an unsupported ``selected_backend`` on the V0 path.
        """
        # ---------------------------------------------------------------- MLA
        if use_mla:
            # TODO(lucas): refactor to be more concise
            #  we should probably consider factoring out V1 here
            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            auto_select = selected_backend is None

            # All candidate flags are computed up front and in this order, so
            # every capability probe runs when auto-selecting.
            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                auto_select and cls.is_device_capability(100)
                and block_size == 128)
            use_flashinfermla = (
                selected_backend == _Backend.FLASHINFER_MLA
                or (auto_select and cls.is_device_capability(100)
                    and block_size in (32, 64)))
            use_flashmla = selected_backend in (
                _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1) or (
                    auto_select and is_flashmla_supported()[0])
            use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                auto_select and flash_attn_supports_mla())
            use_triton = selected_backend == _Backend.TRITON_MLA or auto_select

            def _mla_backend_path(name, import_suffix) -> str:
                # Same backend module under the V0 or V1 package tree.
                if not use_v1:
                    logger.info_once(f"Using {name} backend.")
                    return f"vllm.attention.backends.{import_suffix}"
                logger.info_once(f"Using {name} backend on V1 engine.")
                return f"vllm.v1.attention.backends.mla.{import_suffix}"

            if use_cutlassmla:
                if not use_v1:
                    logger.warning(
                        "Cutlass MLA backend is only supported on V1 engine")
                else:
                    logger.info_once("Using Cutlass MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "cutlass_mla.CutlassMLABackend")

            if use_flashinfermla:
                if not use_v1:
                    logger.warning(
                        "FlashInfer MLA backend is only supported on V1 engine"
                    )
                else:
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                    logger.info_once(
                        "Using FlashInfer MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashinfer_mla.FlashInferMLABackend")

            if use_flashmla:
                if block_size == 64:
                    return _mla_backend_path("FlashMLA",
                                             "flashmla.FlashMLABackend")
                logger.warning(
                    "FlashMLA backend is not supported for block size %d"
                    " (currently only supports block size 64).",
                    block_size)

            if use_flashattn:
                if not use_v1:
                    logger.warning(
                        "FlashAttention MLA backend is only supported on V1 "
                        "engine.")
                else:
                    logger.info_once(
                        "Using FlashAttention MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashattn_mla.FlashAttnMLABackend")

            if use_triton:
                return _mla_backend_path("Triton MLA",
                                         "triton_mla.TritonMLABackend")

        # ----------------------------------------------------------------- V1
        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

            # FlashInfer is special-cased: on SM 10.0 it also switches the KV
            # cache layout to HND.
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                if cls.has_device_capability(100):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                return FLASHINFER_V1

            # Remaining explicit selections: backend -> (log message, path).
            explicit_v1_backends = {
                _Backend.FLEX_ATTENTION:
                ("Using FlexAttention backend on V1 engine.",
                 FLEX_ATTENTION_V1),
                _Backend.TRITON_ATTN_VLLM_V1:
                ("Using Triton backend on V1 engine.", TRITON_ATTN_VLLM_V1),
                _Backend.FLASH_ATTN:
                ("Using Flash Attention backend on V1 engine.",
                 FLASH_ATTN_V1),
                _Backend.TREE_ATTN:
                ("Using Tree Attention backend on V1 engine.", TREE_ATTN_V1),
                _Backend.XFORMERS_VLLM_V1:
                ("Using XFormers backend on V1 engine.", XFORMERS_V1),
            }
            explicit_choice = explicit_v1_backends.get(selected_backend)
            if explicit_choice is not None:
                message, backend_path = explicit_choice
                logger.info_once(message)
                return backend_path

            from vllm.attention.selector import is_attn_backend_supported

            # Default backends for V1 engine
            # Prefer FlashInfer for Blackwell GPUs if installed
            if cls.is_device_capability(100):
                support = is_attn_backend_supported(FLASHINFER_V1, head_size,
                                                    dtype)
                if support:
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)

                    logger.info_once(
                        "Using FlashInfer backend with HND KV cache layout on "
                        "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                    set_kv_cache_layout("HND")
                    return FLASHINFER_V1

                if not support.can_import:
                    logger.warning_once(
                        "FlashInfer failed to import for V1 engine on "
                        "Blackwell (SM 10.0) GPUs; it is recommended to "
                        "install FlashInfer for better performance.")

            # FlexAttention is the default for older GPUs
            if not cls.has_device_capability(80):
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1

            # FlashAttention is the default for SM 8.0+ GPUs; models with
            # sinks use Triton unless the GPU is SM 9.0.
            if has_sink and not cls.is_device_capability(90):
                logger.info_once("Using Triton backend on V1 engine.")
                return TRITON_ATTN_VLLM_V1

            support = is_attn_backend_supported(FLASH_ATTN_V1,
                                                head_size,
                                                dtype,
                                                allow_import_error=False)
            if support:
                logger.info_once("Using Flash Attention backend on "
                                 "V1 engine.")
                return FLASH_ATTN_V1

            assert not support

            # Explain which input ruled out FlashAttention.
            flex_reasons = []
            if not support.head_size:
                flex_reasons.append(f"head_size={head_size}")
            if not support.dtype:
                flex_reasons.append(f"dtype={dtype}")

            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                ", ".join(flex_reasons),
            )
            return FLEX_ATTENTION_V1

        # ----------------------------------------------------------------- V0
        # Backends for V0 engine
        explicit_v0_backends = {
            _Backend.XFORMERS:
            ("Using XFormers backend.",
             "vllm.attention.backends.xformers.XFormersBackend"),
            _Backend.DUAL_CHUNK_FLASH_ATTN:
            ("Using DualChunkFlashAttention backend.",
             "vllm.attention.backends.dual_chunk_flash_attn."
             "DualChunkFlashAttentionBackend"),
            _Backend.DIFFERENTIAL_FLASH_ATTN:
            ("Using DifferentialFlashAttention backend.",
             "vllm.attention.backends.differential_flash_attn."
             "DifferentialFlashAttentionBackend"),
        }
        if selected_backend in explicit_v0_backends:
            message, backend_path = explicit_v0_backends[selected_backend]
            logger.info(message)
            return backend_path
        if selected_backend and selected_backend != _Backend.FLASH_ATTN:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        # FlashAttention (explicit or default) unless a requirement fails.
        use_flash_attn = True
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            use_flash_attn = False
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            use_flash_attn = False
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            use_flash_attn = False

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if use_flash_attn:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend, flash_attn_supports_fp8)

                if head_size not in (
                        FlashAttentionBackend.get_supported_head_sizes()):
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    use_flash_attn = False

                if (kv_cache_dtype is not None
                        and kv_cache_dtype.startswith("fp8")
                        and not flash_attn_supports_fp8()):
                    logger.info(
                        "Cannot use FlashAttention backend for FP8 KV cache.")
                    use_flash_attn = False
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                use_flash_attn = False

        if not use_flash_attn:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Dotted path of the Punica (LoRA) wrapper used on CUDA."""
        wrapper_path = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
        return wrapper_path

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Dotted path of the device communicator used on CUDA."""
        communicator_path = "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        return communicator_path

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 needs compute capability 8.9 or newer."""
        min_fp8_capability = 89
        return cls.has_device_capability(min_fp8_capability)

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        """V1 engine support; CUDA accepts every model (``model_config`` unused)."""
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """CUDA always enables vLLM's custom all-reduce path."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """CUDA reports the attention op as opaque (always ``True``)."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Dotted path of the CUDA-graph wrapper used for static graphs."""
        graph_wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return graph_wrapper_path

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an NCCL-backed ``ProcessGroup`` directly from ``prefix_store``.

        The group is constructed by hand rather than through
        ``init_process_group``. NCCL is always used; the ``backend`` argument
        is accepted for interface compatibility but not read. ``timeout`` is
        applied through the NCCL backend options.
        """
        assert is_nccl_available()
        process_group: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        nccl_options = ProcessGroupNCCL.Options()
        nccl_options._timeout = timeout

        nccl_backend = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                        nccl_options)
        nccl_type = ProcessGroup.BackendType.NCCL
        cuda_device = torch.device("cuda")
        # Order matters: set default backend, sync sequence numbers, then
        # register the backend for the CUDA device.
        process_group._set_default_backend(nccl_type)
        nccl_backend._set_sequence_number_for_group()

        process_group._register_backend(cuda_device, nccl_type, nccl_backend)
        return process_group

    @classmethod
    def device_count(cls) -> int:
        """Number of visible CUDA devices, via ``cuda_device_count_stateless``."""
        visible_devices = cuda_device_count_stateless()
        return visible_devices

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Return whether ``kv_cache_dtype`` works with the attention backend.

        The backend is read from ``VLLM_ATTENTION_BACKEND``. When it is unset:
        - MLA models assume CUTLASS_MLA on Blackwell (SM 10.0) and FLASHMLA
          elsewhere.
        - Other models assume FlashAttention.

        FlashAttention is recognised under both ``"FLASH_ATTN"`` (the name the
        V1 backend selector uses) and the legacy ``"FLASH_ATTN_VLLM_V1"``.
        Backends not listed here are reported as unsupported.
        """
        fp8_attention = kv_cache_dtype.startswith("fp8")
        attention_backend = envs.VLLM_ATTENTION_BACKEND

        if model_config is not None and model_config.use_mla:
            # Default to CutlassMLA for blackwell, FlashMLA otherwise
            if attention_backend is None:
                attention_backend = ("CUTLASS_MLA"
                                     if cls.is_device_capability(100) else
                                     "FLASHMLA")

            # FlashMLA, CUTLASS_MLA and FlashInfer MLA support fp8; any other
            # MLA backend only accepts non-fp8 KV cache dtypes.
            if attention_backend in ("FLASHMLA", "CUTLASS_MLA",
                                     "FLASHINFER_MLA"):
                return True
            return not fp8_attention

        # Default to FlashAttention
        if attention_backend is None:
            attention_backend = "FLASH_ATTN"

        # All Blackwell backends support fp8
        if cls.is_device_capability(100):
            return True

        if attention_backend in ("FLASH_ATTN", "FLASH_ATTN_VLLM_V1"):
            if not fp8_attention:
                return True
            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
            return flash_attn_supports_fp8()

        if attention_backend == "FLASHINFER":
            return True

        if attention_backend == "TRITON_ATTN_VLLM_V1":
            return cls.supports_fp8()

        return False

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise ``ValueError`` if ``torch_dtype`` cannot run on this GPU.

        Only bfloat16 is checked: it requires compute capability >= 8.0.
        """
        if torch_dtype != torch.bfloat16:
            return
        if cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()
        compute_str = ("does not have a compute capability"
                       if capability is None else
                       f"has compute capability {capability.as_version_str()}")

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """CUDA supports the hybrid KV cache (always ``True``)."""
        return True

588

589
590
591
592
593
# NVML utils
# NVML ignores `CUDA_VISIBLE_DEVICES` and works on physical device ids, so the
# `device_id`-based queries below first translate the logical id to a
# physical one. The main benefit of NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    Public queries are wrapped in ``with_nvml_context``, so NVML is
    initialized before and shut down after each call.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of logical device ``device_id``.

        Results are memoized per ``(cls, device_id)``; cached calls skip NVML.
        Returns ``None`` if the id translation or NVML lookup raises
        ``RuntimeError``.
        """
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(
                cls.device_id_to_physical_device_id(device_id))
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        except RuntimeError:
            return None
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Base-class capability check, with ``RuntimeError`` mapped to False."""
        try:
            meets_capability = super().has_device_capability(
                capability, device_id)
        except RuntimeError:
            meets_capability = False
        return meets_capability

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of logical device ``device_id``."""
        return cls._get_physical_device_name(
            cls.device_id_to_physical_device_id(device_id))

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of logical device ``device_id``."""
        handle = pynvml.nvmlDeviceGetHandleByIndex(
            cls.device_id_to_physical_device_id(device_id))
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total`` from NVML memory info for logical ``device_id``."""
        handle = pynvml.nvmlDeviceGetHandleByIndex(
            cls.device_id_to_physical_device_id(device_id))
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        ``physical_device_ids`` are NVML (physical) indices. Returns False at
        the first pair whose NVLink P2P status is not OK, or when NVML fails
        to report a pair's status.
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(idx)
            for idx in physical_device_ids
        ]
        for pos, handle in enumerate(handles):
            # Visit each unordered pair exactly once.
            for peer_handle in handles[pos + 1:]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Caller must already hold an NVML context; ``device_id`` is physical.
        return pynvml.nvmlDeviceGetName(
            pynvml.nvmlDeviceGetHandleByIndex(device_id))

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn about mixed GPU models unless ``CUDA_DEVICE_ORDER=PCI_BUS_ID``."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices <= 1:
            return
        names = [cls._get_physical_device_name(i) for i in range(num_devices)]
        mixed_models = len(set(names)) > 1
        if (mixed_models
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through ``torch.cuda``.

    Selected when NVML cannot be initialized (for example on Jetson, where
    NVML is not supported). ``device_id`` values are passed straight to
    ``torch.cuda``, with no logical-to-physical translation.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of ``device_id`` (memoized)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``torch.cuda.get_device_name(device_id)``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from ``torch.cuda.get_device_properties``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Report NVLink connectivity; always ``False`` without NVML.

        NVLink topology cannot be queried here, so no NVLink is assumed.
        """
        # No exception is active at this point. ``logger.exception`` would log
        # a spurious "NoneType: None" traceback at ERROR level, so a warning
        # is used instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Pick the NVML-backed platform when NVML can be initialized; otherwise fall
# back to the torch.cuda-based one.
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# Let the selected platform report environment warnings at import time.
CudaPlatform.log_warnings()