cuda.py 26.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from datetime import timedelta
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
11

12
import torch
13
14
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
15
from typing_extensions import ParamSpec
16

17
18
# import custom ops, trigger op registration
import vllm._C  # noqa
19
import vllm.envs as envs
20
from vllm.logger import init_logger
21
from vllm.utils import cuda_device_count_stateless, import_pynvml
22

23
from .interface import DeviceCapability, Platform, PlatformEnum
24

25
if TYPE_CHECKING:
26
    from vllm.attention.backends.registry import _Backend
27
    from vllm.config import ModelConfig, VllmConfig
28
29
else:
    _Backend = None
30

31
32
logger = init_logger(__name__)

33
34
35
_P = ParamSpec("_P")
_R = TypeVar("_R")

36
pynvml = import_pynvml()
37

38
39
40
41
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

42

43
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap ``fn`` so that every call runs between NVML init and shutdown.

    ``pynvml.nvmlInit()`` is called right before ``fn`` and
    ``pynvml.nvmlShutdown()`` is called in a ``finally`` clause, so the
    shutdown happens whether ``fn`` returns normally or raises.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same signature and metadata as ``fn``.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


55
56
class CudaPlatformBase(Platform):
    """Platform logic shared by the CUDA platform implementations.

    ``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory`` and ``is_fully_connected`` raise
    ``NotImplementedError`` here; subclasses provide them.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the torch dtypes offered for this GPU's compute capability."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether the given devices are fully connected (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Log platform warnings; the base implementation does nothing."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply CUDA-specific defaults to ``vllm_config`` in place.

        * Replaces an ``"auto"`` worker class with the V1 GPU worker.
        * Defaults the KV cache block size to 16 when it is unset.
        * For MLA models, forces the block size required by the chosen
          (or default) MLA attention backend.
        * Disables CUDA graphs when the DeepEP high-throughput all2all
          backend is combined with data parallelism.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            use_sparse = hasattr(model_config.hf_config, "index_topk")

            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False
            use_flashinfer_mla = False

            forced_backend = envs.VLLM_ATTENTION_BACKEND
            if forced_backend is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    # TODO: This does not work, because the
                    # global_force_attn_backend_context_manager is not set.
                    # See vllm/attention/selector.py:_cached_get_attn_backend
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                use_flashmla = forced_backend == "FLASHMLA"
                use_cutlass_mla = forced_backend == "CUTLASS_MLA"
                use_flashinfer_mla = forced_backend == "FLASHINFER_MLA"

            from vllm.attention.ops.flashmla import is_flashmla_supported

            if (
                use_flashmla
                and is_flashmla_supported()[0]
                and cache_config.block_size != 64
            ):
                cache_config.block_size = 64
                logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size != 128:
                cache_config.block_size = 128
                logger.info(
                    "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
                )

            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA backend."
                )

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse backend."
                )

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if (
            envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
            and parallel_config.data_parallel_size > 1
            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # TODO: Piecewise Cuda graph might be enabled
            # if torch compile cache key issue fixed
            # See https://github.com/vllm-project/vllm/pull/25093
            logger.info(
                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
                "kernels are optimized for prefill and are incompatible with "
                "CUDA Graphs. "
                "In order to use CUDA Graphs for decode-optimized workloads, "
                "set VLLM_ALL2ALL_BACKEND to another option, such as "
                "deepep_low_latency, pplx, or allgather_reducescatter."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    @classmethod
    def get_current_memory_usage(
        cls, device: Optional[torch.types.Device] = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated(device)`` after a reset.

        The CUDA cache is emptied and the peak memory statistics of
        ``device`` are reset immediately before the value is read.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        """Pick the attention backend enum for ViT attention.

        Returns ``TORCH_SDPA`` on devices with capability 10.0 or higher,
        ``FLASH_ATTN`` on capability 8.0+ when the dtype is fp16/bf16 and the
        V1 FlashAttention backend reports support, and ``XFORMERS`` otherwise.
        """
        from vllm.attention.backends.registry import _Backend

        # For Blackwell GPUs, force TORCH_SDPA for now.
        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
        if cls.has_device_capability(100):
            return _Backend.TORCH_SDPA

        if dtype not in (torch.float16, torch.bfloat16):
            return _Backend.XFORMERS

        if cls.has_device_capability(80):
            FLASH_ATTN_V1 = (
                "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            )
            from vllm.attention.selector import is_attn_backend_supported

            if is_attn_backend_supported(
                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
            ):
                return _Backend.FLASH_ATTN

        # Fallback for Volta/Turing GPUs or FA not supported
        return _Backend.XFORMERS

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink,
        use_sparse,
    ) -> str:
        """Return the dotted import path of the attention backend class.

        An explicitly ``selected_backend`` is honored first; otherwise a
        default is chosen from the device capability, ``block_size``,
        ``kv_cache_dtype``, ``has_sink`` and what the candidate backends
        report as supported for ``head_size``/``dtype``.

        Raises:
            RuntimeError: If ``use_v1`` is false (V0 backends were removed,
                and MLA backends require V1).
        """
        from vllm.attention.backends.registry import _Backend

        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them."
                )

            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashmla_sparse."
                    "FlashMLASparseBackend"
                )

            # Every candidate flag is computed up front; the chain below
            # then tries them in priority order.
            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                selected_backend is None
                and cls.is_device_capability(100)
                and block_size == 128
            )
            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
                selected_backend is None
                and cls.is_device_capability(100)
                and block_size in [32, 64]
            )
            use_flashmla = selected_backend == _Backend.FLASHMLA or (
                selected_backend is None and is_flashmla_supported()[0]
            )
            use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                selected_backend is None and flash_attn_supports_mla()
            )
            use_triton = selected_backend == _Backend.TRITON_MLA or (
                selected_backend is None
            )

            if use_cutlassmla:
                logger.info_once("Using Cutlass MLA backend on V1 engine.")
                return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
            if use_flashinfermla:
                from vllm.v1.attention.backends.utils import set_kv_cache_layout

                set_kv_cache_layout("HND")
                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
                )
            if use_flashmla:
                if block_size != 64:
                    # Not returned here: later candidates are still tried.
                    logger.warning(
                        "FlashMLA backend is not supported for block size %d"
                        " (currently only supports block size 64).",
                        block_size,
                    )
                else:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
            if use_flashattn:
                logger.info_once("Using FlashAttention MLA backend on V1 engine.")
                return (
                    "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
                )
            if use_triton:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = (
                "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            )
            TRITON_ATTN = (
                "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            )
            FLASH_ATTN_V1 = (
                "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            )
            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

            use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
                "fp8"
            )

            # Explicitly requested backends
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                if cls.has_device_capability(100):
                    from vllm.v1.attention.backends.utils import set_kv_cache_layout

                    set_kv_cache_layout("HND")
                return FLASHINFER_V1
            elif selected_backend == _Backend.FLEX_ATTENTION:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1
            elif selected_backend == _Backend.TRITON_ATTN:
                logger.info_once("Using Triton backend on V1 engine.")
                return TRITON_ATTN
            elif selected_backend == _Backend.FLASH_ATTN:
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return FLASH_ATTN_V1
            elif selected_backend == _Backend.TREE_ATTN:
                logger.info_once("Using Tree Attention backend on V1 engine.")
                return TREE_ATTN_V1
            elif selected_backend == _Backend.XFORMERS:
                logger.info_once("Using XFormers backend on V1 engine.")
                return XFORMERS_V1

            from vllm.attention.selector import is_attn_backend_supported

            # Default backends for V1 engine
            # Prefer FlashInfer for Blackwell GPUs if installed
            if cls.is_device_capability(100):
                if is_default_backend_supported := is_attn_backend_supported(
                    FLASHINFER_V1, head_size, dtype
                ):
                    from vllm.v1.attention.backends.utils import set_kv_cache_layout

                    logger.info_once(
                        "Using FlashInfer backend with HND KV cache layout on "
                        "V1 engine by default for Blackwell (SM 10.0) GPUs."
                    )
                    set_kv_cache_layout("HND")

                    return FLASHINFER_V1

                if not is_default_backend_supported.can_import:
                    logger.warning_once(
                        "FlashInfer failed to import for V1 engine on "
                        "Blackwell (SM 10.0) GPUs; it is recommended to "
                        "install FlashInfer for better performance."
                    )

            # FlashAttention is the default for SM 8.0+ GPUs
            if cls.has_device_capability(80):
                if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
                    logger.info_once("Using Triton backend on V1 engine.")
                    return TRITON_ATTN
                elif is_default_backend_supported := is_attn_backend_supported(
                    FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
                ):
                    logger.info_once("Using Flash Attention backend on V1 engine.")
                    return FLASH_ATTN_V1

            # FlexAttention is the default for older GPUs
            else:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1

            # Reaching this point means the last support probe was falsy;
            # report which properties it rejected.
            assert not is_default_backend_supported

            use_flex_attention_reason = {}
            if not is_default_backend_supported.head_size:
                use_flex_attention_reason["head_size"] = head_size
            if not is_default_backend_supported.dtype:
                use_flex_attention_reason["dtype"] = dtype

            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
            )
            return FLEX_ATTENTION_V1

        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend."
        )

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the GPU Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the CUDA device communicator class."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether the device meets capability 89 (``has_device_capability``)."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return ``True``: this platform opts in to custom all-reduce."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Return ``True`` for this platform."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted path of the CUDA graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ``ProcessGroup`` backed by a new ``ProcessGroupNCCL``.

        The NCCL backend is created with ``timeout`` set on its options and
        is registered on the process group for the ``"cuda"`` device as the
        default backend. ``backend`` is not read by this implementation.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the value of ``cuda_device_count_stateless()``."""
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """Return whether ``kv_cache_dtype`` works with the attention backend.

        The backend is taken from ``VLLM_ATTENTION_BACKEND``; when unset, the
        default is CUTLASS_MLA (Blackwell) or FLASHMLA for MLA models and
        FLASH_ATTN otherwise. Non-MLA backends not handled below yield
        ``False``.
        """
        fp8_attention = kv_cache_dtype.startswith("fp8")
        attention_backend = envs.VLLM_ATTENTION_BACKEND

        if model_config is not None and model_config.use_mla:
            # Default to CutlassMLA for blackwell,
            # FlashMLA otherwise
            if attention_backend is None:
                if cls.is_device_capability(100):
                    attention_backend = "CUTLASS_MLA"
                else:
                    attention_backend = "FLASHMLA"

            # Only FlashMLA, CUTLASS_MLA and FLASHINFER_MLA support fp8
            if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]:
                return True
            return not fp8_attention

        # Default to FlashAttention
        if attention_backend is None:
            attention_backend = "FLASH_ATTN"

        # All Blackwell backends support fp8
        if cls.is_device_capability(100):
            return True
        if attention_backend == "FLASH_ATTN":
            if not fp8_attention:
                return True
            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8

            return flash_attn_supports_fp8()
        if attention_backend == "FLASHINFER":
            return True
        if attention_backend == "TRITON_ATTN":
            return cls.supports_fp8()
        return False

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise if ``torch_dtype`` is bfloat16 and the GPU is below SM 8.0.

        Raises:
            ValueError: When bfloat16 is requested on a device that does not
                have compute capability 8.0 or higher.
        """
        # Only bfloat16 is checked; the capability query is skipped otherwise.
        if torch_dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        if capability is None:
            compute_str = "does not have a compute capability"
        else:
            version_str = capability.as_version_str()
            compute_str = f"has compute capability {version_str}"

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU."""
        # Blocks are indexed along dim 1; the gathered blocks are moved to
        # dst_cache's device before being written.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU)."""
        # Blocks are indexed along dim 1; the gathered blocks are moved to
        # the CPU before being written.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Return ``True``: hybrid KV cache is reported as supported."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Return ``True``: static graph mode is reported as supported."""
        return True

576

577
578
579
580
581
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML (``pynvml``).

    The public query methods are wrapped in ``with_nvml_context``, so each
    call runs between NVML init and shutdown.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id``.

        ``device_id`` is mapped to a physical index first. The result is
        cached per argument set. ``None`` is returned if a ``RuntimeError``
        is raised during the lookup.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Delegate to the parent check; a ``RuntimeError`` yields ``False``."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID for ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field of NVML's memory info for ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Each unordered pair of devices is checked once. Returns ``False`` as
        soon as one pair does not report ``NVML_P2P_STATUS_OK`` for NVLink,
        or when the NVML query raises ``NVMLError``.
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Only peers after position i: every pair is visited once.
            for peer_handle in handles[i + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name for a physical index.

        Not wrapped in ``with_nvml_context``; callers in this class invoke it
        from methods that are.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when differently named GPUs are present without PCI ordering.

        The warning is emitted only if NVML reports more than one device,
        their names are not all equal, and ``CUDA_DEVICE_ORDER`` is not
        ``PCI_BUS_ID``.
        """
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return

        device_names = [cls._get_physical_device_name(i) for i in range(device_count)]
        if (
            len(set(device_names)) > 1
            and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
        ):
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through ``torch.cuda``."""

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda`` (cached)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always return ``False``; no NVLink query is made on this platform."""
        # No exception is being handled here, so logger.exception() would
        # append a meaningless "NoneType: None" traceback; log a warning.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # The probe only needs to know that init succeeds, so shut NVML down again.
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()