cuda.py 24.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from collections.abc import Callable
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, TypeVar
11

12
import torch
13
from typing_extensions import ParamSpec
14

15
16
# import custom ops, trigger op registration
import vllm._C  # noqa
17
from vllm.logger import init_logger
18
from vllm.utils.import_utils import import_pynvml
19
from vllm.utils.torch_utils import cuda_device_count_stateless
20
from vllm.v1.attention.backends.registry import AttentionBackendEnum
21

22
from .interface import DeviceCapability, Platform, PlatformEnum
23

24
if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheDType
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # Runtime placeholders so these names exist when not type checking;
    # in this module they are only referenced from annotations.
    VllmConfig = None
    CacheDType = None
31

32
33
logger = init_logger(__name__)

# Type variables that let decorators (e.g. with_nvml_context) keep the
# wrapped callable's parameter and return types.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# NVML bindings, resolved through vLLM's import helper.
pynvml = import_pynvml()

# PyTorch 2.5 uses the cuDNN SDPA backend by default, which crashes on some
# models; see https://github.com/huggingface/diffusers/issues/9704 for details.
torch.backends.cuda.enable_cudnn_sdp(False)

43

44
45
46
47
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
    num_heads: int | None = None,
) -> list[AttentionBackendEnum]:
    """Return candidate attention backends, highest priority first.

    The ordering depends on whether MLA is used and on whether the device's
    compute-capability major version is 10. ``num_heads`` only influences
    the relative order of the two sparse MLA backends on major-10 devices.
    Results are memoized per argument tuple by ``functools.cache``, so the
    same list object is returned for repeated calls with equal arguments.
    """
    backends = AttentionBackendEnum
    on_major_10 = device_capability.major == 10

    if not use_mla:
        if on_major_10:
            return [
                backends.FLASHINFER,
                backends.FLASH_ATTN,
                backends.TRITON_ATTN,
                backends.FLEX_ATTENTION,
            ]
        return [
            backends.FLASH_ATTN,
            backends.FLASHINFER,
            backends.TRITON_ATTN,
            backends.FLEX_ATTENTION,
        ]

    if not on_major_10:
        return [
            backends.FLASH_ATTN_MLA,
            backends.FLASHMLA,
            backends.FLASHINFER_MLA,
            backends.TRITON_MLA,
            backends.FLASHMLA_SPARSE,
        ]

    # Prefer FlashInfer at low head counts (FlashMLA uses padding).
    sparse_order = [backends.FLASHMLA_SPARSE, backends.FLASHINFER_MLA_SPARSE]
    if num_heads is not None and num_heads <= 16:
        sparse_order.reverse()
    return [
        backends.FLASHINFER_MLA,
        backends.CUTLASS_MLA,
        backends.FLASH_ATTN_MLA,
        backends.FLASHMLA,
        backends.TRITON_MLA,
        *sparse_order,
    ]


97
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap *fn* so each call runs between ``nvmlInit()`` and ``nvmlShutdown()``.

    ``nvmlShutdown()`` runs in a ``finally`` clause, so it is called even
    when *fn* raises. If ``nvmlInit()`` itself raises, *fn* is not called
    and no shutdown is attempted.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


109
110
class CudaPlatformBase(Platform):
    """Shared CUDA platform logic.

    The device-query classmethods (``get_device_capability``,
    ``get_device_name``, ``get_device_total_memory``, ``is_fully_connected``)
    raise ``NotImplementedError`` here and must be provided by subclasses.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    ]

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Dtypes usable on the current device, chosen by compute capability."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of *device_id*; subclass hook."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name of *device_id*; subclass hook."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of *device_id*; subclass hook."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether *device_ids* are all NVLink-connected; subclass hook."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform-specific startup warnings; a no-op by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply CUDA-specific defaults to *vllm_config* in place.

        - Resolves ``worker_cls == "auto"`` to the v1 GPU worker.
        - Forces ``disable_chunked_mm_input`` for multimodal prefix-LM models.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
        """Pick a KV-cache block size when the user did not specify one.

        A user-specified ``cache_config.block_size`` is left untouched. The
        size falls back to 16 when there is no model config, the model is
        hybrid, or no attention layers are found. Otherwise the first
        attention layer's backend is asked for its preferred size, with 16
        passed in.
        """
        cache_config = vllm_config.cache_config
        if cache_config.block_size is not None:
            # User specified --block-size; keep it.
            return

        model_config = vllm_config.model_config
        # model_config may be None during testing.
        # Skip hybrid models — their block_size is managed by
        # HybridAttentionMambaModelConfig.
        if model_config is None or model_config.is_hybrid:
            cache_config.block_size = 16
            return

        from vllm.config.vllm import (
            get_layers_from_vllm_config,
            set_current_vllm_config,
        )
        from vllm.model_executor.layers.attention_layer_base import (
            AttentionLayerBase,
        )

        attn_layers = get_layers_from_vllm_config(
            vllm_config,
            AttentionLayerBase,
        )
        if not attn_layers:
            cache_config.block_size = 16
            return

        # Only the first attention layer's backend is consulted.
        first_layer = next(iter(attn_layers.values()))
        backend_cls = first_layer.get_attn_backend()
        with set_current_vllm_config(vllm_config):
            preferred = backend_cls.get_preferred_block_size(16)
        if preferred != 16:
            logger.info(
                "Setting kv cache block size to %d for %s backend.",
                preferred,
                backend_cls.get_name(),
            )
        cache_config.block_size = preferred

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return currently allocated CUDA memory (bytes) for *device*.

        The peak counters are reset first, so ``max_memory_allocated``
        reflects the allocation at the time of the reset.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @staticmethod
    def _format_invalid_reasons(
        invalid_reasons: dict["AttentionBackendEnum", tuple[int, list[str]]],
    ) -> str:
        """Render ``{BACKEND: [reason, ...], ...}`` for log and error messages."""
        return (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, (_, reasons) in invalid_reasons.items()
            )
            + "}"
        )

    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> tuple[
        list[tuple["AttentionBackendEnum", int]],
        dict["AttentionBackendEnum", tuple[int, list[str]]],
    ]:
        """Validate every prioritized backend against the configuration.

        Returns:
            ``(valid, invalid)`` where ``valid`` is a list of
            ``(backend, priority)`` in priority order and ``invalid`` maps a
            backend to ``(priority, reasons)``. A backend whose class cannot
            be imported is reported with the reason ``"ImportError"``.
        """
        valid_backends_priorities = []
        invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}

        backend_priorities = _get_backend_priorities(
            attn_selector_config.use_mla,
            device_capability,
            num_heads,
        )
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = (priority, invalid_reasons_i)
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def select_attention_backend(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        attn_selector_config: "AttentionSelectorConfig",
        device_capability: "DeviceCapability",
        raise_on_invalid: bool = True,
        num_heads: int | None = None,
    ) -> "AttentionBackendEnum | None":
        """Select the best attention backend for the given configuration.

        Args:
            selected_backend: User-specified backend, or None for auto-selection
            attn_selector_config: Configuration for attention selection
            device_capability: Device capability info
            raise_on_invalid: If True, raise ValueError when no valid backend
            num_heads: Number of attention heads per GPU, used for backend
                priority ordering on Blackwell GPUs

        Returns:
            The selected backend enum, or None if no valid backend found
            and raise_on_invalid is False
        """
        # First try checking just the selected backend, if there is one.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                validation_errors = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                validation_errors = ["ImportError"]
            if validation_errors:
                if raise_on_invalid:
                    raise ValueError(
                        f"Selected backend {selected_backend} is not valid for "
                        f"this configuration. Reason: {validation_errors}"
                    )
                return None
            return selected_backend

        # No selected backend, so find the best valid one.
        valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
        )

        if not valid_backends_priorities:
            if raise_on_invalid:
                raise ValueError(
                    f"No valid attention backend found for {cls.device_name} "
                    f"with {attn_selector_config!r}. "
                    f"Reasons: {cls._format_invalid_reasons(invalid_reasons)}."
                )
            return None

        # Select the one with the highest priority (lowest index); min() keeps
        # the first minimum, matching a stable sort's tie-breaking.
        chosen_backend, chosen_priority = min(
            valid_backends_priorities, key=lambda pair: pair[1]
        )

        # If the user specified --block-size (but not --attention-backend),
        # check whether that constraint precluded any higher-priority backends.
        if attn_selector_config.block_size is not None:
            excluded = [
                backend
                for backend, (priority, reasons) in invalid_reasons.items()
                if priority < chosen_priority
                and reasons == ["block_size not supported"]
            ]
            if excluded:
                names = ", ".join(b.name for b in excluded)
                logger.warning(
                    "--block-size %d excluded higher-priority backend(s) "
                    "%s. Using %s instead, which may result in reduced "
                    "performance. Consider removing --block-size to "
                    "auto-select the optimal block size.",
                    attn_selector_config.block_size,
                    names,
                    chosen_backend.name,
                )

        return chosen_backend

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Select an attention backend and return its import path.

        Raises:
            ValueError: propagated from ``select_attention_backend`` when the
                selected backend is invalid or no backend is valid.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        chosen_backend = cls.select_attention_backend(
            selected_backend=selected_backend,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
            device_capability=device_capability,
            raise_on_invalid=True,
        )
        assert chosen_backend is not None  # raise_on_invalid=True guarantees this

        # Log the selection
        if selected_backend is not None:
            logger.info("Using %s backend.", chosen_backend)
        else:
            # Recompute the valid/invalid split, for logging only.
            valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
                device_capability=device_capability,
                attn_selector_config=attn_selector_config,
                num_heads=num_heads,
            )
            if invalid_reasons:
                # Only report when something was rejected; otherwise the
                # message would misleadingly end with "Reasons: {}".
                logger.debug_once(
                    f"Some attention backends are not valid for {cls.device_name} with "
                    f"{attn_selector_config!r}. "
                    f"Reasons: {cls._format_invalid_reasons(invalid_reasons)}."
                )
            logger.info_once(
                "Using %s attention backend out of potential backends: %s",
                chosen_backend.name,
                tuple(backend.name for backend, _ in valid_backends_priorities),
                scope="local",
            )

        return chosen_backend.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Backends allowed for ViT attention, in auto-selection order."""
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Return the ViT attention backend to use.

        An explicit *backend* must be one of the supported ViT backends.
        Otherwise the first non-SDPA backend that can be imported and that
        supports *head_size*, *dtype* and (when known) the device's compute
        capability is chosen; ``TORCH_SDPA`` is the fallback.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        cc = cls.get_device_capability()
        for vit_attn_backend in cls.get_supported_vit_attn_backends():
            if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
                continue
            try:
                backend_class = vit_attn_backend.get_class()
                is_backend_supported = backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype)
                if cc is not None:
                    is_backend_supported = (
                        is_backend_supported
                        and backend_class.supports_compute_capability(cc)
                    )
                if is_backend_supported:
                    logger.info_once(
                        f"Using backend {vit_attn_backend} for vit attention"
                    )
                    return vit_attn_backend
            except ImportError:
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Import path of the LoRA Punica wrapper for GPUs."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Import path of the CUDA device communicator."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 requires compute capability 8.9 or newer."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Import path of the CUDA graph wrapper."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        """Number of visible CUDA devices, via the stateless counter."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ValueError if *dtype* is bfloat16 on a pre-8.0 device."""
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy the selected blocks of src_cache into dst_cache.

        Blocks are indexed along dimension 1; the gathered blocks are moved
        to ``dst_cache.device`` before assignment.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy the selected blocks (dimension 1) from GPU to host (CPU)."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True

545

546
547
548
549
550
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    Each query below runs inside ``with_nvml_context``. Logical device ids
    are translated to NVML (physical) indices with
    ``device_id_to_physical_device_id``; ``is_fully_connected`` is the
    exception and takes physical ids directly.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Compute capability of *device_id*, or None on RuntimeError."""
        try:
            physical_id = cls.device_id_to_physical_device_id(device_id)
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(physical_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(nvml_handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Base-class capability check; a RuntimeError is treated as False."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """NVML-reported name of the logical device *device_id*."""
        return cls._get_physical_device_name(
            cls.device_id_to_physical_device_id(device_id)
        )

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """NVML-reported UUID of the logical device *device_id*."""
        physical_id = cls.device_id_to_physical_device_id(device_id)
        return pynvml.nvmlDeviceGetUUID(pynvml.nvmlDeviceGetHandleByIndex(physical_id))

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total memory (as reported by NVML) of the logical device *device_id*."""
        physical_id = cls.device_id_to_physical_device_id(device_id)
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(physical_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of devices must report NVLink P2P status OK.
        An NVML error during the query is logged and treated as "not
        connected".
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Visit each unordered pair once (upper triangle, j > i).
            for peer_handle in handles[i + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """NVML name of the *physical* device index *device_id*.

        Not wrapped in ``with_nvml_context``; callers in this class invoke it
        from methods that already hold an NVML context.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when heterogeneous GPUs are present without PCI_BUS_ID ordering."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices <= 1:
            return
        device_names = [cls._get_physical_device_name(i) for i in range(num_devices)]
        if (
            len(set(device_names)) > 1
            and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
        ):
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``.

    Selected when NVML cannot be initialized. ``device_id`` is passed
    straight to ``torch.cuda`` without any physical-id mapping.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Compute capability of *device_id* as reported by torch."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Device name of *device_id* as reported by torch."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total memory of *device_id* from ``torch.cuda.get_device_properties``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Report NVLink connectivity; always False on this platform.

        There is no NVLink query on this platform, so it conservatively
        assumes no NVLink.
        """
        # logger.exception() is meant to be called from an exception handler;
        # with no active exception it appends a bogus "NoneType: None"
        # traceback. Nothing failed here, so log a plain warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Pick the NVML-backed platform when NVML can be initialized, otherwise fall
# back to the torch.cuda-backed one. NVML is only probed here, so it is shut
# down again right after a successful init.
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()