cuda.py 26.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from datetime import timedelta
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
11

12
import torch
13
14
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
15
from typing_extensions import ParamSpec
16

17
18
# import custom ops, trigger op registration
import vllm._C  # noqa
19
import vllm.envs as envs
20
from vllm.logger import init_logger
21
from vllm.utils import cuda_device_count_stateless, import_pynvml
22

23
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
24

25
if TYPE_CHECKING:
    # Only needed for the string annotations ("ModelConfig", "VllmConfig")
    # used further down; never imported at runtime.
    from vllm.config import ModelConfig, VllmConfig

logger = init_logger(__name__)

# Type variables used by `with_nvml_context` to preserve the signature of the
# callable it wraps.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# NVML bindings. Whether NVML can actually be initialised is probed at the
# bottom of this module, where the concrete platform class is chosen.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

39

40
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap *fn* so that every call runs between NVML init and shutdown.

    ``pynvml.nvmlInit()`` is called before *fn* runs, and
    ``pynvml.nvmlShutdown()`` is always called afterwards, even when *fn*
    raises. The wrapper keeps *fn*'s metadata via :func:`functools.wraps`
    and has the same parameter/return types as *fn*.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # Initialise outside the ``try``: if init itself fails there is
        # nothing to shut down.
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


53
54
class CudaPlatformBase(Platform):
    """CUDA platform logic shared by the NVML and non-NVML variants.

    Raw device queries (capability, name, total memory, NVLink topology)
    raise ``NotImplementedError`` here and are provided by the subclasses
    defined later in this module.
    """
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the torch dtypes usable on the current device."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of *device_id* (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the marketing name of *device_id* (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of *device_id* (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether all *device_ids* are NVLink-connected (hook)."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Log platform-specific warnings; nothing to report by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA-specific defaults on *vllm_config* in place.

        Sets the default worker class and KV-cache block size, forces the
        block size required by the chosen MLA backend, and disables CUDA
        graphs for DeepEP high-throughput all2all with data parallelism.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            use_sparse = hasattr(vllm_config.model_config.hf_config,
                                 "index_topk")
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False
            use_flashinfer_mla = False

            if envs.VLLM_ATTENTION_BACKEND is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    # TODO: This does not work, because the
                    # global_force_attn_backend_context_manager is not set.
                    # See vllm/attention/selector.py:_cached_get_attn_backend
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                use_cutlass_mla = (
                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
                use_flashinfer_mla = (
                    envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")

            from vllm.attention.ops.flashmla import is_flashmla_supported
            if use_flashmla and is_flashmla_supported()[0] \
                and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size != 128:
                cache_config.block_size = 128
                logger.info("Forcing kv cache block size to 128 for "
                            "CUTLASS_MLA backend.")

            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA "
                    "backend.")

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse "
                    "backend.")

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                and parallel_config.data_parallel_size > 1
                and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
            # TODO: Piecewise Cuda graph might be enabled
            # if torch compile cache key issue fixed
            # See https://github.com/vllm-project/vllm/pull/25093
            logger.info(
                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
                "kernels are optimized for prefill and are incompatible with "
                "CUDA Graphs. "
                "In order to use CUDA Graphs for decode-optimized workloads, "
                "set VLLM_ALL2ALL_BACKEND to another option, such as "
                "deepep_low_latency, pplx, or allgather_reducescatter.")
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return allocated CUDA memory (bytes) on *device*.

        The CUDA cache is emptied and the peak statistics are reset first, so
        the returned "max allocated" value reflects current allocations
        rather than a historical peak.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:
        """Choose the attention backend for vision-transformer layers."""
        # For Blackwell GPUs, force TORCH_SDPA for now.
        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
        if cls.has_device_capability(100):
            return _Backend.TORCH_SDPA

        if dtype not in (torch.float16, torch.bfloat16):
            return _Backend.XFORMERS

        if cls.has_device_capability(80):
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            from vllm.attention.selector import is_attn_backend_supported
            is_default_fa_supported = is_attn_backend_supported(
                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False)
            if is_default_fa_supported:
                return _Backend.FLASH_ATTN
            else:
                # Fallback to XFORMERS
                return _Backend.XFORMERS
        else:
            # Fallback for Volta/Turing GPUs or FA not supported
            return _Backend.XFORMERS

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink, use_sparse) -> str:
        """Return the import path of the attention backend class to use.

        ``selected_backend`` is an explicitly requested ``_Backend`` or
        ``None`` for automatic selection. MLA layers are always routed to an
        MLA backend; other layers go through the regular V1 selection.

        Raises:
            RuntimeError: if ``use_v1`` is false (V0 backends were removed).
            ValueError: if ``use_mla`` is set but the explicitly selected
                backend cannot serve an MLA layer with ``block_size``.
        """
        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them.")

            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla.flashmla_sparse."
                        "FlashMLASparseBackend")

            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size == 128)
            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size in [32, 64])
            use_flashmla = selected_backend == _Backend.FLASHMLA or (
                selected_backend is None and is_flashmla_supported()[0])
            use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                selected_backend is None and flash_attn_supports_mla())
            use_triton = selected_backend == _Backend.TRITON_MLA or (
                selected_backend is None)

            if use_cutlassmla:
                logger.info_once("Using Cutlass MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "cutlass_mla.CutlassMLABackend")
            if use_flashinfermla:
                from vllm.v1.attention.backends.utils import (
                    set_kv_cache_layout)
                set_kv_cache_layout("HND")
                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashinfer_mla.FlashInferMLABackend")
            if use_flashmla:
                if block_size != 64:
                    logger.warning(
                        "FlashMLA backend is not supported for block size %d"
                        " (currently only supports block size 64).",
                        block_size)
                else:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashmla.FlashMLABackend")
            if use_flashattn:
                logger.info_once(
                    "Using FlashAttention MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashattn_mla.FlashAttnMLABackend")
            if use_triton:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")

            # Only reachable when an explicit backend was requested that no
            # MLA branch above accepts. Falling through to the regular V1
            # selection would hand a non-MLA backend to an MLA layer, so
            # fail loudly instead.
            raise ValueError(
                f"Attention backend {selected_backend} cannot be used for "
                f"MLA layers with block size {block_size}. Select an MLA "
                "backend (e.g. TRITON_MLA) or leave VLLM_ATTENTION_BACKEND "
                "unset.")
        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

            use_fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))

            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                if cls.has_device_capability(100):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                return FLASHINFER_V1
            elif selected_backend == _Backend.FLEX_ATTENTION:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1
            elif selected_backend == _Backend.TRITON_ATTN:
                logger.info_once("Using Triton backend on V1 engine.")
                return TRITON_ATTN
            elif selected_backend == _Backend.FLASH_ATTN:
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return FLASH_ATTN_V1
            elif selected_backend == _Backend.TREE_ATTN:
                logger.info_once("Using Tree Attention backend on V1 engine.")
                return TREE_ATTN_V1
            elif selected_backend == _Backend.XFORMERS:
                logger.info_once("Using XFormers backend on V1 engine.")
                return XFORMERS_V1

            from vllm.attention.selector import is_attn_backend_supported

            # Default backends for V1 engine
            # Prefer FlashInfer for Blackwell GPUs if installed
            if cls.is_device_capability(100):
                if is_default_backend_supported := is_attn_backend_supported(
                        FLASHINFER_V1, head_size, dtype):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)

                    logger.info_once(
                        "Using FlashInfer backend with HND KV cache layout on "
                        "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                    set_kv_cache_layout("HND")

                    return FLASHINFER_V1

                if not is_default_backend_supported.can_import:
                    logger.warning_once(
                        "FlashInfer failed to import for V1 engine on "
                        "Blackwell (SM 10.0) GPUs; it is recommended to "
                        "install FlashInfer for better performance.")

            # FlashAttention is the default for SM 8.0+ GPUs
            if cls.has_device_capability(80):
                if (has_sink or
                        use_fp8_kv_cache) and not cls.is_device_capability(90):
                    logger.info_once("Using Triton backend on V1 engine.")
                    return TRITON_ATTN
                elif is_default_backend_supported := is_attn_backend_supported(
                        FLASH_ATTN_V1, head_size, dtype,
                        allow_import_error=False):
                    logger.info_once("Using Flash Attention backend on "
                                     "V1 engine.")
                    return FLASH_ATTN_V1

            # FlexAttention is the default for older GPUs
            else:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1

            # Only reached on SM 8.0+ when FlashAttention rejected the
            # head size / dtype; the walrus above bound the result.
            assert not is_default_backend_supported

            use_flex_attention_reason = {}
            if not is_default_backend_supported.head_size:
                use_flex_attention_reason["head_size"] = head_size
            if not is_default_backend_supported.dtype:
                use_flex_attention_reason["dtype"] = dtype

            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                ", ".join(f"{k}={v}"
                          for k, v in use_flex_attention_reason.items()),
            )
            return FLEX_ATTENTION_V1

        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend.")

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the GPU Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the CUDA device communicator."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether the device has compute capability 8.9 or higher."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the CUDA graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an NCCL-backed ``ProcessGroup`` without the global c10d state.

        ``backend`` is accepted for interface compatibility but ignored:
        NCCL is always used on this platform.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the number of visible CUDA devices."""
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Return whether *kv_cache_dtype* works with the attention backend
        that will be used (``VLLM_ATTENTION_BACKEND`` or the platform
        default for this device/model)."""
        fp8_attention = kv_cache_dtype.startswith("fp8")
        attention_backend = envs.VLLM_ATTENTION_BACKEND

        supported = False
        if model_config is not None and model_config.use_mla:
            # Default to CutlassMLA for blackwell,
            # FlashMLA otherwise
            if attention_backend is None:
                if cls.is_device_capability(100):
                    attention_backend = "CUTLASS_MLA"
                else:
                    attention_backend = "FLASHMLA"

            # FlashMLA, CUTLASS_MLA and FlashInfer MLA support fp8; every
            # other MLA backend only accepts non-fp8 cache dtypes.
            if attention_backend in [
                    "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"
            ]:
                supported = True
            else:
                supported = (not fp8_attention)
        else:
            # Default to FlashAttention
            if attention_backend is None:
                attention_backend = "FLASH_ATTN"

            # All Blackwell backends support fp8
            if cls.is_device_capability(100):
                supported = True
            elif attention_backend == "FLASH_ATTN":
                if fp8_attention:
                    from vllm.attention.utils.fa_utils import (
                        flash_attn_supports_fp8)
                    supported = flash_attn_supports_fp8()
                else:
                    supported = True
            elif attention_backend == "FLASHINFER":
                supported = True
            elif attention_backend == "TRITON_ATTN":
                supported = cls.supports_fp8()
            # NOTE(review): any other backend (e.g. FLEX_ATTENTION) reports
            # unsupported even for non-fp8 dtypes -- confirm this is intended.
        return supported

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise ``ValueError`` if *torch_dtype* is bfloat16 and the device
        has compute capability below 8.0."""
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU.

        Blocks are selected along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU).

        Blocks are selected along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True
538

539
540
541
542
543
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    Logical device ids are translated with
    ``device_id_to_physical_device_id`` before being passed to NVML.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of *device_id*, or ``None`` if any
        step raises ``RuntimeError``. Memoized per arguments.

        NOTE(review): ``pynvml.NVMLError`` is only caught here if it derives
        from ``RuntimeError`` -- confirm for the installed pynvml.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Like the base implementation, but ``RuntimeError`` means ``False``."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of logical device *device_id*."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of logical device *device_id*."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return total memory (as reported by NVML) of *device_id*."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                # Each unordered pair is checked once.
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle,
                            peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if"
                            " your machine has no NVLink equipped.")
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Takes a *physical* id; callers must already hold an NVML context.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when differently named GPUs are present and
        ``CUDA_DEVICE_ORDER`` is not ``PCI_BUS_ID`` (presumably because
        CUDA and NVML may then enumerate devices differently)."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_count)
            ]
            if (len(set(device_names)) > 1
                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform used when NVML cannot be initialised (e.g. on Jetson).

    Device queries go through ``torch.cuda`` instead of NVML.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from ``torch.cuda`` device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always ``False``: NVLink topology cannot be queried without NVML."""
        # Not inside an exception handler, so ``logger.exception`` would
        # attach a meaningless "NoneType: None" traceback; a warning is the
        # appropriate level for this expected condition.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Pick the concrete CUDA platform: probe NVML once at import time and fall
# back to the torch-based implementation when it cannot be initialised
# (NVML is unavailable on Jetson, for example).
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    nvml_available = False
else:
    # The probe succeeded; release the handle again straight away.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()