cuda.py 25.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
8
from __future__ import annotations

9
import os
10
from collections.abc import Callable
11
from datetime import timedelta
12
from functools import cache, lru_cache, wraps
13
from typing import TYPE_CHECKING, TypeVar
14

15
import torch
16
17
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
18
from typing_extensions import ParamSpec
19

20
21
# import custom ops, trigger op registration
import vllm._C  # noqa
22
import vllm._C_stable_libtorch  # noqa
23
import vllm.envs as envs
24
from vllm.logger import init_logger
25
from vllm.utils.import_utils import import_pynvml
26
from vllm.v1.attention.backends.registry import AttentionBackendEnum
27

28
from .interface import DeviceCapability, Platform, PlatformEnum
29

30
if TYPE_CHECKING:
31
    from vllm.config import VllmConfig
32
    from vllm.config.cache import CacheDType
33
    from vllm.config.kernel import IrOpPriorityConfig
34
    from vllm.v1.attention.selector import AttentionSelectorConfig
35
else:
36
37
    VllmConfig = None
    CacheDType = None
38

39
40
logger = init_logger(__name__)

41
42
43
_P = ParamSpec("_P")
_R = TypeVar("_R")

44
pynvml = import_pynvml()
45

46
47
48
49
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
    """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES
    at the time of call.

    This should be used instead of torch.accelerator.device_count() unless
    CUDA_VISIBLE_DEVICES has already been set to the desired value.

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released."""
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda

    if not torch.cuda._is_compiled():
        return 0
    raw_count = torch.cuda._device_count_nvml()
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r


77
78
79
80
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
    num_heads: int | None = None,
    kv_cache_dtype: CacheDType | None = None,
) -> list[AttentionBackendEnum]:
    """Return the attention backends to try, highest priority first.

    Args:
        use_mla: Select among the MLA backends instead of the regular ones.
        device_capability: Only ``major == 10`` changes the ordering.
        num_heads: Used only for the sparse MLA ordering on major 10.
        kv_cache_dtype: Used only for the sparse MLA ordering on major 10;
            values starting with ``"fp8"`` are treated as FP8 KV cache.

    Returns:
        The ordered backend list. The result is memoized by ``@cache``, so
        the same list object is handed to every caller with equal arguments
        and must not be mutated.
    """
    is_major_10 = device_capability.major == 10

    if not use_mla:
        if is_major_10:
            return [
                AttentionBackendEnum.FLASHINFER,
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.FLEX_ATTENTION,
            ]
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]

    if not is_major_10:
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    # Sparse MLA backend priorities
    # See https://github.com/vllm-project/vllm/issues/35807 for
    # benchmark results
    # FlashInfer is preferred for fp8 KV cache, and for BF16 KV cache at low
    # head counts (FlashMLA uses padding); otherwise FlashMLA goes first.
    fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8")
    low_head_count = num_heads is not None and num_heads <= 16
    if fp8_kv_cache or low_head_count:
        sparse_backends = [
            AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]
    else:
        sparse_backends = [
            AttentionBackendEnum.FLASHMLA_SPARSE,
            AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
        ]

    return [
        AttentionBackendEnum.FLASHINFER_MLA,
        AttentionBackendEnum.CUTLASS_MLA,
        AttentionBackendEnum.FLASH_ATTN_MLA,
        AttentionBackendEnum.FLASHMLA,
        AttentionBackendEnum.TRITON_MLA,
        *sparse_backends,
    ]


143
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so each call runs between nvmlInit and nvmlShutdown.

    ``pynvml.nvmlInit()`` is called before ``fn``; ``pynvml.nvmlShutdown()``
    is called afterwards whether ``fn`` returns or raises. If ``nvmlInit``
    itself raises, ``fn`` is not called and no shutdown is attempted.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


155
156
class CudaPlatformBase(Platform):
    # Platform identity and the string keys used for this platform.
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # Name of the environment variable that controls device visibility.
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    ]
166

167
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """List the torch dtypes usable on this device, by capability tier."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

179
180
181
182
183
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

190
    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``.

        Not implemented on this base class; always raises NotImplementedError.
        """
        raise NotImplementedError
193

194
195
196
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id``.

        Not implemented on this base class; always raises NotImplementedError.
        """
        raise NotImplementedError
197

198
199
200
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id``.

        Not implemented on this base class; always raises NotImplementedError.
        """
        raise NotImplementedError
201

202
    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether the given devices are fully connected.

        Not implemented on this base class; always raises NotImplementedError.
        """
        raise NotImplementedError
205

206
207
208
    @classmethod
    def log_warnings(cls):
        """Hook for platform-specific warnings; does nothing on this base class."""
        pass
209

210
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Apply CUDA-specific defaults to ``vllm_config`` in place.

        Resolves an ``"auto"`` worker class to the v1 GPU worker, and forces
        ``disable_chunked_mm_input`` on for multimodal prefix-LM models that
        have not already disabled it.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

232
    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the allocated-memory reading for ``device``.

        Empties the CUDA caching allocator, resets the peak-memory statistics
        for ``device``, and returns ``torch.cuda.max_memory_allocated(device)``
        read immediately after that reset.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

240
    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: AttentionSelectorConfig,
        num_heads: int | None = None,
    ) -> tuple[
        list[tuple[AttentionBackendEnum, int]],
        dict[AttentionBackendEnum, tuple[int, list[str]]],
    ]:
        """Split the prioritized backends into valid and invalid ones.

        Each backend's position in the priority list is its priority (lower
        is better). A backend whose class cannot be imported is recorded as
        invalid with the single reason ``"ImportError"``.

        Returns:
            A pair ``(valid, invalid)``: ``valid`` is a list of
            ``(backend, priority)`` in priority order; ``invalid`` maps each
            rejected backend to ``(priority, reasons)``.
        """
        valid_backends_priorities: list[tuple[AttentionBackendEnum, int]] = []
        invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}

        backend_priorities = _get_backend_priorities(
            attn_selector_config.use_mla,
            device_capability,
            num_heads,
            attn_selector_config.kv_cache_dtype,
        )
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = (priority, invalid_reasons_i)
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons
274

275
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: AttentionBackendEnum | None,
        attn_selector_config: AttentionSelectorConfig,
        num_heads: int | None = None,
    ) -> str:
        """Pick an attention backend and return its import path.

        If ``selected_backend`` is given, only that backend is validated and
        it is either returned or rejected. Otherwise the highest-priority
        backend that validates against ``attn_selector_config`` is chosen.

        Raises:
            ValueError: If the explicitly selected backend is invalid for the
                configuration, or if no backend is valid at all.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # First try checking just the selected backend, if there is one.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was explicitly selected, so search for a valid one.
        valid_backends_priorities, all_invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, (_, reasons) in all_invalid_reasons.items()
            )
            + "}"
        )
        config_str = repr(attn_selector_config)
        logger.debug_once(
            f"Some attention backends are not valid for {cls.device_name} with "
            f"{config_str}. Reasons: {reasons_str}."
        )
        if not valid_backends_priorities:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # We have found some valid backends. Select the one with the
        # highest priority (lowest priority number; first one wins on ties).
        selected_backend, selected_priority = min(
            valid_backends_priorities, key=lambda entry: entry[1]
        )

        # If the user specified --block-size (but not --attention-backend),
        # check whether that constraint precluded any higher-priority backends.
        if attn_selector_config.block_size is not None:
            excluded = [
                backend
                for backend, (priority, reasons) in all_invalid_reasons.items()
                if priority < selected_priority
                and reasons == ["block_size not supported"]
            ]
            if excluded:
                names = ", ".join(b.name for b in excluded)
                logger.warning(
                    "--block-size %d precluded higher-priority backend(s) "
                    "%s. Using %s instead, which may result in reduced "
                    "performance. Consider removing --block-size to "
                    "auto-select the optimal block size.",
                    attn_selector_config.block_size,
                    names,
                    selected_backend.name,
                )

        logger.info_once(
            "Using %s attention backend out of potential backends: %s.",
            selected_backend.name,
            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
            scope="local",
        )

        return selected_backend.get_path()
369

370
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list[AttentionBackendEnum]:
        """List the ViT attention backends in preference order.

        Both capability tiers return the same four backends; only the order
        of TRITON_ATTN and TORCH_SDPA differs.
        """
        if cls.has_device_capability(80):
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.FLASHINFER,
            ]
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLASHINFER,
        ]
386
387
388
389
390
391

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: AttentionBackendEnum | None = None,
    ) -> AttentionBackendEnum:
        """Choose the attention backend for vision-transformer layers.

        An explicit ``backend`` is returned as-is after asserting that it is
        in the supported list. Otherwise the supported backends are tried in
        order and the first one that accepts ``head_size``, ``dtype`` and
        (when known) the device capability is returned. TORCH_SDPA is
        accepted without checks and is also the final fallback.
        """
        supported_backends = cls.get_supported_vit_attn_backends()

        if backend is not None:
            assert backend in supported_backends, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {supported_backends}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        cc = cls.get_device_capability()
        for vit_attn_backend in supported_backends:
            if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
                return vit_attn_backend
            try:
                backend_class = vit_attn_backend.get_class()
                is_backend_supported = backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype)
                if cc is not None:
                    is_backend_supported = (
                        is_backend_supported
                        and backend_class.supports_compute_capability(cc)
                    )
                if is_backend_supported:
                    logger.info_once(
                        f"Using backend {vit_attn_backend} for vit attention",
                        scope="local",
                    )
                    return vit_attn_backend
            except ImportError:
                # Best effort: a backend that cannot be imported is skipped.
                pass

        return AttentionBackendEnum.TORCH_SDPA

427
428
429
430
    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted import path of the Punica wrapper class for GPU."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

431
432
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted import path of the CUDA device communicator class."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )
436

437
438
439
440
    @classmethod
    def supports_fp8(cls) -> bool:
        """Return the result of ``has_device_capability(89)``.

        NOTE(review): 89 presumably encodes compute capability 8.9, by analogy
        with the 80 -> "8.0" usage elsewhere in this class; confirm.
        """
        return cls.has_device_capability(89)

441
442
443
444
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Report whether custom all-reduce is used; always True here."""
        return True

445
446
447
448
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Report whether the attention op is opaque; always True here."""
        return True

449
    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted import path of the CUDA graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
452

453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ProcessGroup backed by NCCL and registered for "cuda".

        Args:
            backend: Not read by this method; NCCL is always used.
            prefix_store: Store passed to both the group and the NCCL backend.
            group_rank: Rank of this process within the group.
            group_size: Number of processes in the group.
            timeout: Assigned to the NCCL backend options.

        Returns:
            The new ProcessGroup with NCCL set as its default backend.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        # Imported here rather than at module level.
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        # NOTE(review): these private torch calls are kept in their original
        # order; whether the order matters is not visible from here.
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

484
485
    @classmethod
    def device_count(cls) -> int:
        """Return the CUDA device count for the current CUDA_VISIBLE_DEVICES.

        The environment value is passed along so the cached helper is keyed
        on it.
        """
        return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
487

488
    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise if ``dtype`` is bfloat16 and the device cannot run it.

        Only ``torch.bfloat16`` is checked; every other dtype passes.

        Raises:
            ValueError: If ``dtype`` is bfloat16 and
                ``has_device_capability(80)`` is false.
        """
        # The capability query only runs for bfloat16 (short-circuit).
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        if capability is None:
            compute_str = "does not have a compute capability"
        else:
            version_str = capability.as_version_str()
            compute_str = f"has compute capability {version_str}"

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )
508

509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy the selected blocks of src_cache into dst_cache.

        Blocks are selected along dimension 1 of each cache; the gathered
        source blocks are moved to dst_cache's device before being written.
        """
        gathered_blocks = src_cache[:, src_block_indices]
        blocks_on_target = gathered_blocks.to(dst_cache.device)
        dst_cache[:, dst_block_indices] = blocks_on_target

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy the selected blocks of src_cache into dst_cache via the CPU.

        Blocks are selected along dimension 1 of each cache; the gathered
        source blocks are converted with ``.cpu()`` before being written.
        """
        gathered_blocks = src_cache[:, src_block_indices]
        host_blocks = gathered_blocks.cpu()
        dst_cache[:, dst_block_indices] = host_blocks

533
534
535
536
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Report whether hybrid KV cache is supported; always True here."""
        return True

537
538
539
540
    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Report whether static graph mode is supported; always True here."""
        return True

541
542
543
544
545
    @classmethod
    def support_deep_gemm(cls) -> bool:
        """Report whether DeepGEMM is supported on this device.

        Currently, only Hopper and Blackwell GPUs are supported: the check is
        an exact capability match on 90 or membership in the 100 family.
        """
        return cls.is_device_capability(90) or cls.is_device_capability_family(100)

546
    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Return the multiprocessor count torch reports for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).multi_processor_count

550
551
552
553
    @classmethod
    def use_custom_op_collectives(cls) -> bool:
        """Report whether custom-op collectives are used; always True here."""
        return True

554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
    @classmethod
    def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConfig:
        """Build the default IR op priority configuration for CUDA.

        When compiling with the inductor backend (any mode other than NONE)
        only "native" is used; otherwise "vllm_c" kernels are preferred with
        "native" as the fallback. rms_norm additionally puts "oink" first
        when VLLM_USE_OINK_OPS is set.
        """
        from vllm.config.compilation import CompilationMode
        from vllm.config.kernel import IrOpPriorityConfig

        # Native used by default when compiling,
        # use vllm_c kernels where available when no codegen
        compilation = vllm_config.compilation_config
        if compilation.backend == "inductor" and compilation.mode != CompilationMode.NONE:
            default = ["native"]
        else:
            default = ["vllm_c", "native"]

        # Use oink if enabled for rms_norm
        # TODO(Laurawly/luka): remove this env var,
        #  users can just use IR op priority directly
        rms_norm = ["oink"] + default if envs.VLLM_USE_OINK_OPS else default

        return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)

574

575
576
577
578
579
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
580
    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id`` as reported by NVML.

        The logical ``device_id`` is first mapped to a physical device index.
        Results are memoized by ``@cache``.

        Returns:
            The capability, or None if a RuntimeError is raised while
            resolving or querying the device.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Run the parent class's capability check inside an NVML context.

        Returns False instead of propagating if the parent check raises
        RuntimeError.
        """
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False
603

604
    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for the logical ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)
609

610
611
612
    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID for the logical ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

617
    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the NVML-reported total memory for the logical ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
623

624
    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of the given physical devices is checked once.
        Returns False as soon as one pair lacks an OK NVLink P2P status, or
        if the NVML query raises; returns True otherwise (including when
        fewer than two devices are given).
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Only peers after position i, so each pair is visited once.
            for peer_handle in handles[i + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True
649
650

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of the device at physical index ``device_id``.

        Not wrapped in an NVML context itself; the callers in this class
        invoke it from methods decorated with ``with_nvml_context``.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when the machine mixes GPU models without PCI bus ordering.

        A warning is logged only if NVML reports more than one device, the
        device names are not all identical, and CUDA_DEVICE_ORDER is not
        set to PCI_BUS_ID.
        """
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return

        device_names = [cls._get_physical_device_name(i) for i in range(device_count)]
        if (
            len(set(device_names)) > 1
            and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
        ):
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    # Device queries go through torch.cuda rather than NVML.

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability torch reports for ``device_id``.

        Results are memoized by ``@cache``.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name torch reports for ``device_id``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory torch reports for ``device_id``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always report that the devices are not NVLink-connected.

        No detection is attempted; a message is logged and False is returned.
        """
        # logger.exception is only meaningful inside an exception handler;
        # here there is no active exception, so log a plain warning instead
        # of emitting an empty traceback.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # The probe succeeded; release the NVML context it opened.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()