# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from collections.abc import Callable
9
from datetime import timedelta
10
from functools import cache, wraps
11
from typing import TYPE_CHECKING, TypeVar
12

13
import torch
14
15
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
16
from typing_extensions import ParamSpec
17

18
19
# import custom ops, trigger op registration
import vllm._C  # noqa
20
import vllm.envs as envs
21
from vllm.logger import init_logger
22
from vllm.utils import cuda_device_count_stateless, import_pynvml
23

24
from .interface import DeviceCapability, Platform, PlatformEnum
25

26
# These names are only needed by annotations, so they are imported solely
# while type checking; at runtime `_Backend` is a placeholder.
if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
else:
    _Backend = None

logger = init_logger(__name__)

# Type variables that let decorators keep the wrapped callable's signature.
_P = ParamSpec("_P")
_R = TypeVar("_R")

pynvml = import_pynvml()

# PyTorch 2.5 picks the cuDNN SDPA kernel by default, which crashes on some
# models. See https://github.com/huggingface/diffusers/issues/9704 for details.
torch.backends.cuda.enable_cudnn_sdp(False)

43

44
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that every call is bracketed by NVML init/shutdown.

    ``pynvml.nvmlInit()`` runs before ``fn`` and ``pynvml.nvmlShutdown()``
    runs afterwards, including when ``fn`` raises.
    """

    @wraps(fn)
    def nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
            return result
        finally:
            # Always release NVML, whether fn returned or raised.
            pynvml.nvmlShutdown()

    return nvml_scoped


56
57
class CudaPlatformBase(Platform):
    """Behavior shared by the CUDA platform implementations in this file.

    The device queries (capability, name, total memory, connectivity) raise
    ``NotImplementedError`` here and are provided by the subclasses.
    """

    # Platform identity.
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    # Resource key used for Ray.
    ray_device_key: str = "GPU"
    # torch.distributed backend name.
    dist_backend: str = "nccl"
    # Environment variable that controls which devices are visible.
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
64

65
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """List the torch dtypes usable on this device, widest support first."""
        if self.has_device_capability(80):
            # Ampere, Hopper and later NVIDIA GPUs.
            dtypes = [torch.bfloat16, torch.float16, torch.float32]
        elif self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs: BF16 is not supported.
            dtypes = [torch.float16, torch.float32]
        else:
            # Kepler and Maxwell NVIDIA GPUs: only FP32 is supported,
            # though vLLM doesn't support these GPUs.
            dtypes = [torch.float32]
        return dtypes

77
78
79
80
81
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current CUDA device and force it to take effect.

        Args:
            device: The device to switch to.
        """
        torch.cuda.set_device(device)
        # Allocating a throwaway tensor forces the device to be set eagerly;
        # see https://github.com/pytorch/pytorch/issues/155668 for why and
        # when this is needed.
        torch.zeros(1, device=device)

88
    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()
91

92
93
94
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id``.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()
95

96
97
98
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id``.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()
99

100
    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Report whether the given devices are fully connected.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()
103

104
105
106
    @classmethod
    def log_warnings(cls):
        """Emit platform warnings; the base implementation does nothing."""
        return None
107

108
    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA defaults on ``vllm_config`` and resolve known conflicts.

        The config is mutated in place:

        * ``parallel_config.worker_cls`` is resolved when it is ``"auto"``.
        * ``cache_config.block_size`` defaults to 16 when unset.
        * For MLA models the block size is forced to the value required by
          the MLA backend that is (or will be) selected.
        * CUDA graphs are disabled when the DeepEP high-throughput all2all
          backend is used with data parallelism.

        Args:
            vllm_config: The configuration to check and update.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        # Note: block_size is initialized in
        # HybridAttentionMambaModelConfig.verify_and_update_config
        # for models with both attention and mamba,
        # and doesn't need to be reinitialized here
        if (
            model_config is not None
            and model_config.use_mla
            # cache_config is treated as optional above, so it must not be
            # dereferenced here without a None check.
            and cache_config is not None
            and cache_config.block_size is not None
        ):
            use_sparse = hasattr(model_config.hf_config, "index_topk")
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False
            use_flashinfer_mla = False

            if envs.VLLM_ATTENTION_BACKEND is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    # TODO: This does not work, because the
                    # global_force_attn_backend_context_manager is not set.
                    # See vllm/attention/selector.py:_cached_get_attn_backend
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA"
                use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
                use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"

            from vllm.attention.ops.flashmla import is_flashmla_dense_supported

            if (
                use_flashmla
                and is_flashmla_dense_supported()[0]
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size % 128 != 0:
                cache_config.block_size = 128
                logger.info(
                    "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
                )

            if (
                use_flashinfer_mla
                and cache_config.block_size != 32
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA backend."
                )

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse backend."
                )

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if (
            envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
            and parallel_config.data_parallel_size > 1
            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # TODO: Piecewise Cuda graph might be enabled
            # if torch compile cache key issue fixed
            # See https://github.com/vllm-project/vllm/pull/25093
            logger.info(
                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
                "kernels are optimized for prefill and are incompatible with "
                "CUDA Graphs. "
                "In order to use CUDA Graphs for decode-optimized workloads, "
                "set VLLM_ALL2ALL_BACKEND to another option, such as "
                "deepep_low_latency, pplx, or allgather_reducescatter."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
211

212
    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated`` for ``device``.

        The CUDA cache is emptied and the peak-memory statistics are reset
        immediately before the value is read.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        peak_allocated = torch.cuda.max_memory_allocated(device)
        return peak_allocated

220
    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        """Choose the attention backend for ViT layers.

        Args:
            head_size: Attention head size, passed to the support check.
            dtype: Activation dtype.

        Returns:
            ``TORCH_SDPA`` on Blackwell, ``FLASH_ATTN`` when the device has
            capability 80, the dtype is fp16/bf16 and the FlashAttention
            backend reports support, otherwise ``XFORMERS``.
        """
        from vllm.attention.backends.registry import _Backend

        # For Blackwell GPUs, force TORCH_SDPA for now.
        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
        if cls.has_device_capability(100):
            return _Backend.TORCH_SDPA

        half_precision = (torch.float16, torch.bfloat16)
        if dtype not in half_precision:
            return _Backend.XFORMERS

        # Volta/Turing GPUs fall back to XFORMERS.
        if not cls.has_device_capability(80):
            return _Backend.XFORMERS

        from vllm.attention.selector import is_attn_backend_supported

        flash_attn_path = (
            "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
        )
        if is_attn_backend_supported(
            flash_attn_path, head_size, dtype, allow_import_error=False
        ):
            return _Backend.FLASH_ATTN

        # FlashAttention is not supported for this configuration.
        return _Backend.XFORMERS
249

250
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink,
        use_sparse,
    ) -> str:
        """Resolve the dotted import path of the attention backend class.

        An explicitly ``selected_backend`` is honored first; with
        ``selected_backend`` set to ``None`` a default is picked from the
        device capability and the remaining arguments.

        Returns:
            The fully qualified class path of the chosen backend.

        Raises:
            RuntimeError: If ``use_v1`` is false.
        """
        from vllm.attention.backends.registry import _Backend

        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them."
                )

            from vllm.attention.ops.flashmla import is_flashmla_dense_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashmla_sparse."
                    "FlashMLASparseBackend"
                )

            # All candidate flags are computed up front, in priority order.
            auto_select = selected_backend is None
            want_cutlass = selected_backend == _Backend.CUTLASS_MLA or (
                auto_select
                and cls.is_device_capability(100)
                and block_size % 128 == 0
            )
            want_flashinfer = selected_backend == _Backend.FLASHINFER_MLA or (
                auto_select
                and cls.is_device_capability(100)
                and (block_size == 32 or block_size % 64 == 0)
            )
            want_flashmla = selected_backend == _Backend.FLASHMLA or (
                auto_select and is_flashmla_dense_supported()[0]
            )
            want_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                auto_select and flash_attn_supports_mla()
            )
            want_triton = selected_backend == _Backend.TRITON_MLA or auto_select

            if want_cutlass:
                logger.info_once("Using Cutlass MLA backend on V1 engine.")
                return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"

            if want_flashinfer:
                from vllm.v1.attention.backends.utils import set_kv_cache_layout

                set_kv_cache_layout("HND")
                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
                )

            if want_flashmla:
                if block_size % 64 == 0:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
                # Unsupported block size: warn and try the next candidates.
                logger.warning(
                    "FlashMLA backend is not supported for block size %d"
                    " (currently only supports block size 64).",
                    block_size,
                )

            if want_flashattn:
                logger.info_once("Using FlashAttention MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
                )

            if want_triton:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"

        if not use_v1:
            raise RuntimeError(
                "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
                "to select a supported backend."
            )

        FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
        FLEX_ATTENTION_V1 = (
            "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
        )
        TRITON_ATTN = (
            "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
        )
        FLASH_ATTN_V1 = (
            "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
        )
        TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
        XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

        use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
            "fp8"
        )

        # Explicitly requested backends.
        if selected_backend == _Backend.FLASHINFER:
            logger.info_once("Using FlashInfer backend on V1 engine.")
            if cls.has_device_capability(100):
                from vllm.v1.attention.backends.utils import set_kv_cache_layout

                set_kv_cache_layout("HND")
            return FLASHINFER_V1
        if selected_backend == _Backend.FLEX_ATTENTION:
            logger.info_once("Using FlexAttention backend on V1 engine.")
            return FLEX_ATTENTION_V1
        if selected_backend == _Backend.TRITON_ATTN:
            logger.info_once("Using Triton backend on V1 engine.")
            return TRITON_ATTN
        if selected_backend == _Backend.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend on V1 engine.")
            return FLASH_ATTN_V1
        if selected_backend == _Backend.TREE_ATTN:
            logger.info_once("Using Tree Attention backend on V1 engine.")
            return TREE_ATTN_V1
        if selected_backend == _Backend.XFORMERS:
            logger.info_once("Using XFormers backend on V1 engine.")
            return XFORMERS_V1

        from vllm.attention.selector import is_attn_backend_supported

        # Default backends for V1 engine
        # Prefer FlashInfer for Blackwell GPUs if installed
        if cls.is_device_capability(100):
            flashinfer_support = is_attn_backend_supported(
                FLASHINFER_V1, head_size, dtype
            )
            if flashinfer_support:
                from vllm.v1.attention.backends.utils import set_kv_cache_layout

                logger.info_once(
                    "Using FlashInfer backend with HND KV cache layout on "
                    "V1 engine by default for Blackwell (SM 10.0) GPUs."
                )
                set_kv_cache_layout("HND")

                return FLASHINFER_V1

            if not flashinfer_support.can_import:
                logger.warning_once(
                    "FlashInfer failed to import for V1 engine on "
                    "Blackwell (SM 10.0) GPUs; it is recommended to "
                    "install FlashInfer for better performance."
                )

        # FlexAttention is the default for older GPUs
        if not cls.has_device_capability(80):
            logger.info_once("Using FlexAttention backend on V1 engine.")
            return FLEX_ATTENTION_V1

        # FlashAttention is the default for SM 8.0+ GPUs
        if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
            logger.info_once("Using Triton backend on V1 engine.")
            return TRITON_ATTN

        flash_attn_support = is_attn_backend_supported(
            FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
        )
        if flash_attn_support:
            logger.info_once("Using Flash Attention backend on V1 engine.")
            return FLASH_ATTN_V1

        assert not flash_attn_support

        # Report which properties ruled out FlashAttention.
        use_flex_attention_reason = {}
        if not flash_attn_support.head_size:
            use_flex_attention_reason["head_size"] = head_size
        if not flash_attn_support.dtype:
            use_flex_attention_reason["dtype"] = dtype

        logger.info_once(
            "Using FlexAttention backend for %s on V1 engine.",
            ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
        )
        return FLEX_ATTENTION_V1
431

432
433
434
435
    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted import path of the Punica wrapper class for GPU."""
        punica_cls_path = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
        return punica_cls_path

436
437
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted import path of the CUDA device communicator class."""
        communicator_cls_path = (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
        return communicator_cls_path
441

442
443
444
445
    @classmethod
    def supports_fp8(cls) -> bool:
        """Report FP8 support, keyed on ``has_device_capability(89)``."""
        fp8_capable = cls.has_device_capability(89)
        return fp8_capable

446
447
448
449
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Report whether custom all-reduce is used; always ``True`` here."""
        enabled = True
        return enabled

450
451
452
453
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Report whether the attention op is opaque; always ``True`` here."""
        is_opaque = True
        return is_opaque

454
    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted import path of the CUDA graph wrapper class."""
        wrapper_cls_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
        return wrapper_cls_path
457

458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ``ProcessGroup`` with an NCCL backend registered for CUDA.

        Args:
            backend: Not read by this implementation.
            prefix_store: Store handed to both the group and the NCCL backend.
            group_rank: Rank of this process within the group.
            group_size: Number of processes in the group.
            timeout: Timeout stored on the NCCL backend options.

        Returns:
            The process group with NCCL set as default backend.
        """
        assert is_nccl_available()
        group: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        nccl_options = ProcessGroupNCCL.Options()
        nccl_options._timeout = timeout
        nccl_backend = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, nccl_options
        )

        nccl_type = ProcessGroup.BackendType.NCCL
        cuda_device = torch.device("cuda")
        group._set_default_backend(nccl_type)
        nccl_backend._set_sequence_number_for_group()
        group._register_backend(cuda_device, nccl_type, nccl_backend)

        return group

489
490
491
492
    @classmethod
    def device_count(cls) -> int:
        """Return the device count from ``cuda_device_count_stateless()``."""
        count = cuda_device_count_stateless()
        return count

493
    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """Report whether ``kv_cache_dtype`` works with the attention backend.

        The backend is read from ``VLLM_ATTENTION_BACKEND``; when that is
        unset a default is assumed from the model type and the device.

        Args:
            kv_cache_dtype: KV cache dtype name; names starting with ``"fp8"``
                are treated as FP8.
            model_config: Model configuration, or ``None``.
        """
        wants_fp8 = kv_cache_dtype.startswith("fp8")
        backend_name = envs.VLLM_ATTENTION_BACKEND

        if model_config is not None and model_config.use_mla:
            # Default to CutlassMLA for blackwell, FlashMLA otherwise.
            if backend_name is None:
                if cls.is_device_capability(100):
                    backend_name = "CUTLASS_MLA"
                else:
                    backend_name = "FLASHMLA"

            # FlashMLA, CUTLASS_MLA and FLASHINFER_MLA are always accepted;
            # any other MLA backend is accepted only for non-fp8 caches.
            if backend_name in ("FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"):
                return True
            return not wants_fp8

        # Default to FlashAttention.
        if backend_name is None:
            backend_name = "FLASH_ATTN"

        # All Blackwell backends support fp8.
        if cls.is_device_capability(100):
            return True

        if backend_name == "FLASH_ATTN":
            if not wants_fp8:
                return True
            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8

            return flash_attn_supports_fp8()

        if backend_name == "FLASHINFER":
            return True

        if backend_name == "TRITON_ATTN":
            return cls.supports_fp8()

        return False

536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Reject bfloat16 on devices without compute capability 8.0.

        Args:
            torch_dtype: The dtype that is about to be used.

        Raises:
            ValueError: If ``torch_dtype`` is bfloat16 and the device does not
                have capability 80.
        """
        if torch_dtype != torch.bfloat16:
            return
        if cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        if capability is not None:
            version_str = capability.as_version_str()
            compute_str = f"has compute capability {version_str}"
        else:
            compute_str = "does not have a compute capability"

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
556

557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy selected blocks of ``src_cache`` into ``dst_cache`` on its device.

        Blocks are indexed along dimension 1 of both caches.
        """
        picked_blocks = src_cache[:, src_block_indices]
        moved_blocks = picked_blocks.to(dst_cache.device)
        dst_cache[:, dst_block_indices] = moved_blocks

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy selected blocks of ``src_cache`` to the host-side ``dst_cache``.

        Blocks are indexed along dimension 1 of both caches.
        """
        picked_blocks = src_cache[:, src_block_indices]
        host_blocks = picked_blocks.cpu()
        dst_cache[:, dst_block_indices] = host_blocks

581
582
583
584
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Report hybrid KV cache support; always ``True`` here."""
        supported = True
        return supported

585
586
587
588
    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Report static graph mode support; always ``True`` here."""
        supported = True
        return supported

589

590
591
592
593
594
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through pynvml."""

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``.

        Returns:
            The capability, or ``None`` when the lookup raises ``RuntimeError``.
        """
        try:
            nvml_index = cls.device_id_to_physical_device_id(device_id)
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_index)
            cc_major, cc_minor = pynvml.nvmlDeviceGetCudaComputeCapability(
                nvml_handle
            )
            return DeviceCapability(major=cc_major, minor=cc_minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Delegate to the parent check, returning ``False`` on ``RuntimeError``."""
        try:
            meets_requirement = super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False
        return meets_requirement

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for ``device_id``."""
        nvml_index = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(nvml_index)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID for ``device_id``."""
        nvml_index = cls.device_id_to_physical_device_id(device_id)
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_index)
        return pynvml.nvmlDeviceGetUUID(nvml_handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory NVML reports for ``device_id``."""
        nvml_index = cls.device_id_to_physical_device_id(device_id)
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_index)
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle)
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Query whether the given GPUs are fully connected by NVLink (1 hop).

        Args:
            physical_device_ids: Physical device indices, as used by NVML.

        Returns:
            ``True`` when every pair reports an OK NVLink P2P status; ``False``
            on the first pair that does not, or when the query fails.
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Visit each unordered pair once: every handle against the later ones.
        for position, handle in enumerate(handles):
            for peer_handle in handles[position + 1 :]:
                try:
                    link_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if link_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of a physical device index.

        Not wrapped in an NVML context itself; the callers in this class
        are already decorated with ``with_nvml_context``.
        """
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(nvml_handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when mixed GPU models are present without PCI bus ordering."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices <= 1:
            return

        device_names = [cls._get_physical_device_name(i) for i in range(num_devices)]
        mixed_models = len(set(device_names)) > 1
        if mixed_models and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through ``torch.cuda``.

    Selected at the bottom of this file when NVML cannot be initialized.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability ``torch.cuda`` reports for ``device_id``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name ``torch.cuda`` reports for ``device_id``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device properties of ``device_id``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Report NVLink connectivity; always ``False`` on this platform.

        Args:
            physical_device_ids: Ignored; no detection is attempted.
        """
        # No exception is being handled here, so logger.exception() would
        # append a meaningless "NoneType: None" traceback to the message.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Pick the NVML-backed platform when NVML can be initialized, and the
# torch.cuda-backed one otherwise.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # The probe succeeded; release NVML again.
    pynvml.nvmlShutdown()

if nvml_available:
    CudaPlatform = NvmlCudaPlatform
else:
    CudaPlatform = NonNvmlCudaPlatform

CudaPlatform.log_warnings()