# cuda.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from datetime import timedelta
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
11

12
import torch
13
14
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
15
from typing_extensions import ParamSpec
16

17
18
# import custom ops, trigger op registration
import vllm._C  # noqa
19
import vllm.envs as envs
20
from vllm.logger import init_logger
21
from vllm.utils import cuda_device_count_stateless, import_pynvml
22

23
from .interface import DeviceCapability, Platform, PlatformEnum
24

25
if TYPE_CHECKING:
    # Needed for annotations only. At runtime these names are not imported
    # here; methods that need `_Backend` import it lazily in their own body.
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
else:
    # Runtime placeholder so the module-level name `_Backend` always exists.
    _Backend = None
30

31
32
# Module-level logger for this platform implementation.
logger = init_logger(__name__)

# Generic parameter/return placeholders used to type decorators so that the
# wrapped callable keeps its original signature.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# NVML bindings, obtained through vLLM's import helper.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

42

43
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that every call runs between NVML init and shutdown.

    Args:
        fn: The callable to wrap. Its signature and metadata are preserved
            on the returned wrapper via ``functools.wraps``.

    Returns:
        A wrapper that calls ``pynvml.nvmlInit()``, invokes ``fn`` with the
        given arguments, and calls ``pynvml.nvmlShutdown()`` afterwards,
        whether ``fn`` returned or raised.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # Initialization stays outside the `try`: if nvmlInit() itself
        # raises, nvmlShutdown() is not called.
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


55
56
class CudaPlatformBase(Platform):
    """Shared CUDA platform behaviour.

    Holds the class-level platform identifiers and the configuration /
    attention-backend selection logic. The device queries
    (`get_device_capability`, `get_device_name`, `get_device_total_memory`,
    `is_fully_connected`) raise `NotImplementedError` here and are provided
    by `NvmlCudaPlatform` and `NonNvmlCudaPlatform`.
    """

    # Platform identifiers / keys consumed elsewhere through the class.
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the torch dtypes usable on this device, by capability tier."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]:
        """Return the device's compute capability; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device's name; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the device's total memory; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether the given devices are fully connected; implemented
        by subclasses."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Hook for platform warnings; the base implementation does nothing."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA defaults on ``vllm_config`` and reconcile conflicts.

        Mutates ``vllm_config`` in place:

        * sets ``parallel_config.worker_cls`` when it is ``"auto"``;
        * defaults ``cache_config.block_size`` to 16 when unset;
        * for MLA models, forces the KV-cache block size required by the
          chosen MLA backend (and, on the default Blackwell path, assigns
          ``envs.VLLM_ATTENTION_BACKEND``);
        * disables CUDA graphs when the DeepEP high-throughput all2all
          backend is used with data parallelism.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        # Note: block_size is initialized in
        # HybridAttentionMambaModelConfig.verify_and_update_config
        # for models with both attention and mamba,
        # and doesn't need to be reinitialized here
        if (
            model_config is not None
            and model_config.use_mla
            # The defaulting step above tolerates a missing cache_config, so
            # this branch must not dereference it unconditionally either.
            and cache_config is not None
            and cache_config.block_size is not None
        ):
            use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False
            use_flashinfer_mla = False

            if envs.VLLM_ATTENTION_BACKEND is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    # TODO: This does not work, because the
                    # global_force_attn_backend_context_manager is not set.
                    # See vllm/attention/selector.py:_cached_get_attn_backend
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA"
                use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
                use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"

            from vllm.attention.ops.flashmla import is_flashmla_dense_supported

            if (
                use_flashmla
                and is_flashmla_dense_supported()[0]
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size % 128 != 0:
                cache_config.block_size = 128
                logger.info(
                    "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
                )

            if (
                use_flashinfer_mla
                and cache_config.block_size != 32
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA backend."
                )

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse backend."
                )

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if (
            envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
            and parallel_config.data_parallel_size > 1
            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # TODO: Piecewise Cuda graph might be enabled
            # if torch compile cache key issue fixed
            # See https://github.com/vllm-project/vllm/pull/25093
            logger.info(
                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
                "kernels are optimized for prefill and are incompatible with "
                "CUDA Graphs. "
                "In order to use CUDA Graphs for decode-optimized workloads, "
                "set VLLM_ALL2ALL_BACKEND to another option, such as "
                "deepep_low_latency, pplx, or allgather_reducescatter."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    @classmethod
    def get_current_memory_usage(
        cls, device: Optional[torch.types.Device] = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated(device)`` measured right
        after emptying the CUDA cache and resetting the peak statistics."""
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        """Pick the attention backend enum member for vision transformers.

        Order of checks: capability 100 -> TORCH_SDPA; dtype other than
        fp16/bf16 -> XFORMERS; capability 80 and FlashAttention supported for
        this ``head_size``/``dtype`` -> FLASH_ATTN; otherwise XFORMERS.
        """
        from vllm.attention.backends.registry import _Backend

        # For Blackwell GPUs, force TORCH_SDPA for now.
        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
        if cls.has_device_capability(100):
            return _Backend.TORCH_SDPA

        if dtype not in (torch.float16, torch.bfloat16):
            return _Backend.XFORMERS

        if cls.has_device_capability(80):
            FLASH_ATTN_V1 = (
                "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            )
            from vllm.attention.selector import is_attn_backend_supported

            is_default_fa_supported = is_attn_backend_supported(
                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
            )
            if is_default_fa_supported:
                return _Backend.FLASH_ATTN

        # Fallback for Volta/Turing GPUs or FA not supported
        return _Backend.XFORMERS

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink,
        use_sparse,
    ) -> str:
        """Return the dotted path of the attention backend class to use.

        ``selected_backend`` is an explicit backend choice or ``None`` for
        automatic selection. MLA models are handled first, then the regular
        V1 backends: an explicitly selected backend is honoured, otherwise a
        default is chosen from the device capability and what
        ``is_attn_backend_supported`` reports.

        Raises:
            RuntimeError: if ``use_mla`` is set without ``use_v1``, or if
                ``use_v1`` is false (V0 backends have been removed).
        """
        from vllm.attention.backends.registry import _Backend

        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them."
                )

            from vllm.attention.ops.flashmla import is_flashmla_dense_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashmla_sparse."
                    "FlashMLASparseBackend"
                )

            # Each flag is true when that backend was requested explicitly,
            # or when nothing was requested and the backend is eligible.
            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                selected_backend is None
                and cls.is_device_capability(100)
                and block_size % 128 == 0
            )
            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
                selected_backend is None
                and cls.is_device_capability(100)
                and (block_size == 32 or block_size % 64 == 0)
            )
            use_flashmla = selected_backend == _Backend.FLASHMLA or (
                selected_backend is None and is_flashmla_dense_supported()[0]
            )
            use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                selected_backend is None and flash_attn_supports_mla()
            )
            use_triton = selected_backend == _Backend.TRITON_MLA or (
                selected_backend is None
            )

            # The checks below are ordered by priority; the first match wins.
            if use_cutlassmla:
                logger.info_once("Using Cutlass MLA backend on V1 engine.")
                return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
            if use_flashinfermla:
                from vllm.v1.attention.backends.utils import set_kv_cache_layout

                set_kv_cache_layout("HND")
                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
                )
            if use_flashmla:
                if block_size % 64 != 0:
                    # Not returned: selection continues with the next
                    # candidates below.
                    logger.warning(
                        "FlashMLA backend is not supported for block size %d"
                        " (currently only supports block size 64).",
                        block_size,
                    )
                else:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
            if use_flashattn:
                logger.info_once("Using FlashAttention MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
                )
            if use_triton:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
            # NOTE(review): if no MLA backend was returned above (e.g. an
            # explicitly selected backend that is not usable), control falls
            # through to the non-MLA selection below - confirm this is
            # intended.
        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = (
                "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            )
            TRITON_ATTN = (
                "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            )
            FLASH_ATTN_V1 = (
                "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            )
            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

            use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
                "fp8"
            )

            # Explicitly selected backends are honoured as-is.
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                if cls.has_device_capability(100):
                    from vllm.v1.attention.backends.utils import set_kv_cache_layout

                    set_kv_cache_layout("HND")
                return FLASHINFER_V1
            elif selected_backend == _Backend.FLEX_ATTENTION:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1
            elif selected_backend == _Backend.TRITON_ATTN:
                logger.info_once("Using Triton backend on V1 engine.")
                return TRITON_ATTN
            elif selected_backend == _Backend.FLASH_ATTN:
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return FLASH_ATTN_V1
            elif selected_backend == _Backend.TREE_ATTN:
                logger.info_once("Using Tree Attention backend on V1 engine.")
                return TREE_ATTN_V1
            elif selected_backend == _Backend.XFORMERS:
                logger.info_once("Using XFormers backend on V1 engine.")
                return XFORMERS_V1

            from vllm.attention.selector import is_attn_backend_supported

            # Default backends for V1 engine
            # Prefer FlashInfer for Blackwell GPUs if installed
            if cls.is_device_capability(100):
                if is_default_backend_supported := is_attn_backend_supported(
                    FLASHINFER_V1, head_size, dtype
                ):
                    from vllm.v1.attention.backends.utils import set_kv_cache_layout

                    logger.info_once(
                        "Using FlashInfer backend with HND KV cache layout on "
                        "V1 engine by default for Blackwell (SM 10.0) GPUs."
                    )
                    set_kv_cache_layout("HND")

                    return FLASHINFER_V1

                if not is_default_backend_supported.can_import:
                    logger.warning_once(
                        "FlashInfer failed to import for V1 engine on "
                        "Blackwell (SM 10.0) GPUs; it is recommended to "
                        "install FlashInfer for better performance."
                    )

            # FlashAttention is the default for SM 8.0+ GPUs
            if cls.has_device_capability(80):
                if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
                    logger.info_once("Using Triton backend on V1 engine.")
                    return TRITON_ATTN
                elif is_default_backend_supported := is_attn_backend_supported(
                    FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
                ):
                    logger.info_once("Using Flash Attention backend on V1 engine.")
                    return FLASH_ATTN_V1

            # FlexAttention is the default for older GPUs
            else:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1

            # Only reachable when the FlashAttention support check above
            # was evaluated and came back falsy.
            assert not is_default_backend_supported

            use_flex_attention_reason = {}
            if not is_default_backend_supported.head_size:
                use_flex_attention_reason["head_size"] = head_size
            if not is_default_backend_supported.dtype:
                use_flex_attention_reason["dtype"] = dtype

            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
            )
            return FLEX_ATTENTION_V1

        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend."
        )

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return the result of the capability check against 89."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Always ``True`` on CUDA."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always ``True`` on CUDA."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted path of the static-graph (CUDA graph) wrapper."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ``ProcessGroup`` with an NCCL backend registered for CUDA.

        Args:
            backend: Not read by this implementation; NCCL is always used.
            prefix_store: Store handed to both the group and the backend.
            group_rank: Rank of this process within the group.
            group_size: Number of processes in the group.
            timeout: Assigned to the NCCL backend options.

        Returns:
            The process group with NCCL set as its default backend.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the device count from ``cuda_device_count_stateless()``."""
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """Return whether ``kv_cache_dtype`` works with the attention backend.

        The backend is read from ``envs.VLLM_ATTENTION_BACKEND``; when that is
        unset, a default is assumed (CUTLASS_MLA / FLASHMLA for MLA models
        depending on the device, FLASH_ATTN otherwise). Backends not listed
        in the non-MLA branch yield ``False``.
        """
        fp8_attention = kv_cache_dtype.startswith("fp8")
        attention_backend = envs.VLLM_ATTENTION_BACKEND

        supported = False
        if model_config is not None and model_config.use_mla:
            # Default to CutlassMLA for blackwell,
            # FlashMLA otherwise
            if attention_backend is None:
                if cls.is_device_capability(100):
                    attention_backend = "CUTLASS_MLA"
                else:
                    attention_backend = "FLASHMLA"

            # Only FlashMLA, CUTLASS_MLA and FlashInfer MLA support fp8
            if attention_backend in ("FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"):
                supported = True
            else:
                supported = not fp8_attention
        else:
            # Default to FlashAttention
            if attention_backend is None:
                attention_backend = "FLASH_ATTN"

            # All Blackwell backends support fp8
            if cls.is_device_capability(100):
                supported = True
            elif attention_backend == "FLASH_ATTN":
                if fp8_attention:
                    from vllm.attention.utils.fa_utils import flash_attn_supports_fp8

                    supported = flash_attn_supports_fp8()
                else:
                    supported = True
            elif attention_backend == "FLASHINFER":
                supported = True
            elif attention_backend == "TRITON_ATTN":
                supported = cls.supports_fp8()
        return supported

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise if ``torch_dtype`` is bfloat16 and the device cannot run it.

        Raises:
            ValueError: for ``torch.bfloat16`` when the capability check
                against 80 fails; the message names the GPU and its
                compute capability.
        """
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU.

        Blocks are selected along dimension 1 of both caches; the selected
        source blocks are moved to ``dst_cache.device`` before assignment.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU).

        Blocks are selected along dimension 1 of both caches; the selected
        source blocks are converted with ``.cpu()`` before assignment.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always ``True`` on CUDA."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always ``True`` on CUDA."""
        return True

588

589
590
591
592
593
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML (``pynvml``).

    Public methods taking a ``device_id`` translate it with
    ``device_id_to_physical_device_id`` before querying NVML, and run inside
    an NVML init/shutdown pair via ``with_nvml_context``.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]:
        """Return the device's CUDA compute capability, or ``None`` if a
        ``RuntimeError`` is raised while looking it up. Results are cached."""
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Delegate to the parent check; a ``RuntimeError`` counts as "no"."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of the device behind ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of the device behind ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field of NVML's memory info as an ``int``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of the given physical devices is checked; the
        result is ``False`` as soon as one pair does not report an OK NVLink
        P2P status or the NVML query raises.
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Only peers after `handle`, so each pair is queried once.
            for peer_handle in handles[i + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name for a *physical* device index.

        Not wrapped in ``with_nvml_context``: the callers in this class
        invoke it from inside their own NVML context.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when differently named GPUs are present and
        ``CUDA_DEVICE_ORDER`` is not ``PCI_BUS_ID``."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_count)
            ]
            if (
                len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
            ):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``.

    Selected at module level when NVML cannot be initialized.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda`` for
        ``device_id``. Results are cached."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always report "not fully connected"; NVLink cannot be probed here.

        ``physical_device_ids`` is accepted for interface compatibility and
        is not inspected.
        """
        # No exception is being handled at this point, so `logger.exception`
        # would attach a meaningless "NoneType: None" traceback; log a plain
        # warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # The init call was only a probe; shut NVML down again right away.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# Emit platform warnings once, when this module is imported.
CudaPlatform.log_warnings()