"vllm/vscode:/vscode.git/clone" did not exist on "74333ae2f6c3c4aa4b55301e5ed7aba03a5b09f8"
cuda.py 24.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code in this file can safely assume a CUDA platform (e.g. it may import
pynvml). However, it must not initialize a CUDA context.
"""

7
8
from __future__ import annotations

9
import os
10
from collections.abc import Callable
11
from datetime import timedelta
12
from functools import cache, lru_cache, wraps
13
from typing import TYPE_CHECKING, TypeVar
14

15
import torch
16
17
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
18
from typing_extensions import ParamSpec
19

20
21
# import custom ops, trigger op registration
import vllm._C  # noqa
22
import vllm._C_stable_libtorch  # noqa
23
import vllm.envs as envs
24
from vllm.logger import init_logger
25
from vllm.utils.import_utils import import_pynvml
26
from vllm.v1.attention.backends.registry import AttentionBackendEnum
27

28
from .interface import DeviceCapability, Platform, PlatformEnum
29

30
if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheDType
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # At runtime these names are only referenced in (lazy) annotations.
    VllmConfig = CacheDType = None

logger = init_logger(__name__)

# Type parameters used to type the `with_nvml_context` decorator.
_P = ParamSpec("_P")
_R = TypeVar("_R")

pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

49

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
    """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES
    at the time of call.

    This should be used instead of torch.accelerator.device_count() unless
    CUDA_VISIBLE_DEVICES has already been set to the desired value.

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released."""
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda

    if not torch.cuda._is_compiled():
        return 0
    raw_count = torch.cuda._device_count_nvml()
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r


76
77
78
79
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
    num_heads: int | None = None,
    kv_cache_dtype: CacheDType | None = None,
) -> list[AttentionBackendEnum]:
    """Return candidate attention backends, most preferred first.

    Args:
        use_mla: Choose from the MLA backend family instead of the standard
            attention backends.
        device_capability: Compute capability of the target device; only
            ``major == 10`` versus anything else changes the ordering.
        num_heads: Attention head count. Only consulted for MLA on
            ``major == 10`` devices, to order the sparse MLA backends.
        kv_cache_dtype: KV-cache dtype name. Only consulted for MLA on
            ``major == 10`` devices; an ``fp8*`` value favors FlashInfer.

    Returns:
        A list of ``AttentionBackendEnum`` members in priority order.
    """
    Backend = AttentionBackendEnum
    is_major_10 = device_capability.major == 10

    if not use_mla:
        if is_major_10:
            return [
                Backend.FLASHINFER,
                Backend.FLASH_ATTN,
                Backend.TRITON_ATTN,
                Backend.FLEX_ATTENTION,
            ]
        return [
            Backend.FLASH_ATTN,
            Backend.FLASHINFER,
            Backend.TRITON_ATTN,
            Backend.FLEX_ATTENTION,
        ]

    if not is_major_10:
        return [
            Backend.FLASH_ATTN_MLA,
            Backend.FLASHMLA,
            Backend.FLASHINFER_MLA,
            Backend.TRITON_MLA,
            Backend.FLASHMLA_SPARSE,
        ]

    # Sparse MLA backend priorities.
    # See https://github.com/vllm-project/vllm/issues/35807 for
    # benchmark results.
    # FlashInfer goes first for an fp8 KV cache, and for a non-fp8 KV cache
    # at low head counts (FlashMLA uses padding); otherwise FlashMLA goes
    # first.
    prefer_flashinfer_sparse = (
        kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8")
    ) or (num_heads is not None and num_heads <= 16)
    if prefer_flashinfer_sparse:
        sparse_backends = [Backend.FLASHINFER_MLA_SPARSE, Backend.FLASHMLA_SPARSE]
    else:
        sparse_backends = [Backend.FLASHMLA_SPARSE, Backend.FLASHINFER_MLA_SPARSE]

    return [
        Backend.FLASHINFER_MLA,
        Backend.CUTLASS_MLA,
        Backend.FLASH_ATTN_MLA,
        Backend.FLASHMLA,
        Backend.TRITON_MLA,
        *sparse_backends,
    ]


142
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so every call runs between NVML init and shutdown.

    ``pynvml.nvmlInit()`` is called before ``fn``. ``pynvml.nvmlShutdown()``
    is always called afterwards, even if ``fn`` raises. The wrapper keeps
    ``fn``'s metadata via ``functools.wraps``.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
        return result

    return _nvml_scoped


154
155
class CudaPlatformBase(Platform):
    """Platform behavior shared by every CUDA platform variant.

    The device-query methods (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory``, ``is_fully_connected``) raise
    ``NotImplementedError`` here. The NVML-based and non-NVML subclasses in
    this module implement them.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    ]

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Floating-point dtypes usable on the current device, by capability."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """No-op in the base class; subclasses may emit platform warnings."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in CUDA-specific defaults on ``vllm_config`` in place.

        Resolves an ``"auto"`` worker class to the GPU worker. It also forces
        ``disable_chunked_mm_input`` for multimodal prefix-LM models.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated`` after resetting peaks.

        The caching allocator is emptied and the peak statistics are reset
        first. The returned peak therefore reflects what is currently
        allocated on ``device``.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: AttentionSelectorConfig,
        num_heads: int | None = None,
    ) -> tuple[
        list[tuple[AttentionBackendEnum, int]],
        dict[AttentionBackendEnum, tuple[int, list[str]]],
    ]:
        """Validate every candidate backend for this configuration.

        Returns:
            A pair ``(valid, invalid)``.
            - ``valid`` lists ``(backend, priority)`` for each backend that
              passed validation, in priority order.
            - ``invalid`` maps each rejected backend to
              ``(priority, reasons)``.
            A priority is the backend's index in the candidate list, so
            smaller numbers are preferred.
        """
        valid_backends_priorities: list[tuple[AttentionBackendEnum, int]] = []
        invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}

        backend_priorities = _get_backend_priorities(
            attn_selector_config.use_mla,
            device_capability,
            num_heads,
            attn_selector_config.kv_cache_dtype,
        )
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                # A backend that cannot be imported is treated as invalid.
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = (priority, invalid_reasons_i)
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: AttentionBackendEnum | None,
        attn_selector_config: AttentionSelectorConfig,
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use.

        If ``selected_backend`` is given, it is validated and used. An invalid
        explicit choice raises instead of falling back to another backend.
        Otherwise the valid backend with the smallest priority index is
        chosen.

        Raises:
            ValueError: if ``selected_backend`` is not valid for this
                configuration, or if no candidate backend is valid.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # First try checking just the selected backend, if there is one.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was selected. An invalid explicit selection has already
        # raised above, so search the priority list for a valid backend.
        valid_backends_priorities, all_invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, (_, reasons) in all_invalid_reasons.items()
            )
            + "}"
        )
        config_str = repr(attn_selector_config)
        logger.debug_once(
            f"Some attention backends are not valid for {cls.device_name} with "
            f"{config_str}. Reasons: {reasons_str}."
        )
        if not valid_backends_priorities:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # We have found some valid backends. Select the one with the highest
        # priority, i.e. the smallest priority index. min() returns the first
        # minimal entry, which is the same element a stable sort puts first.
        selected_backend, selected_priority = min(
            valid_backends_priorities, key=lambda entry: entry[1]
        )

        # If the user specified --block-size (but not --attention-backend),
        # check whether that constraint precluded any higher-priority backends.
        if attn_selector_config.block_size is not None:
            excluded = [
                backend
                for backend, (priority, reasons) in all_invalid_reasons.items()
                if priority < selected_priority
                and reasons == ["block_size not supported"]
            ]
            if excluded:
                names = ", ".join(b.name for b in excluded)
                logger.warning(
                    "--block-size %d precluded higher-priority backend(s) "
                    "%s. Using %s instead, which may result in reduced "
                    "performance. Consider removing --block-size to "
                    "auto-select the optimal block size.",
                    attn_selector_config.block_size,
                    names,
                    selected_backend.name,
                )

        logger.info_once(
            "Using %s attention backend out of potential backends: %s.",
            selected_backend.name,
            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
            scope="local",
        )

        return selected_backend.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list[AttentionBackendEnum]:
        """ViT attention backends in preference order for this device.

        On capability >= 8.0, Triton is preferred over torch SDPA. On older
        devices, SDPA comes before Triton.
        """
        if cls.has_device_capability(80):
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.FLASHINFER,
            ]
        else:
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.FLASHINFER,
            ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: AttentionBackendEnum | None = None,
    ) -> AttentionBackendEnum:
        """Pick the ViT attention backend.

        An explicit ``backend`` must be in the supported list; it is checked
        only with an ``assert``. Otherwise the supported backends are scanned
        in order. The first one that is importable and supports
        ``head_size``, ``dtype`` and (if known) the device capability wins.
        Reaching ``TORCH_SDPA`` in the scan returns it unconditionally, and
        ``TORCH_SDPA`` is also the final fallback.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        cc = cls.get_device_capability()
        for vit_attn_backend in cls.get_supported_vit_attn_backends():
            if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
                return vit_attn_backend
            try:
                backend_class = vit_attn_backend.get_class()
                is_backend_supported = backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype)
                if cc is not None:
                    is_backend_supported = (
                        is_backend_supported
                        and backend_class.supports_compute_capability(cc)
                    )
                if is_backend_supported:
                    logger.info_once(
                        f"Using backend {vit_attn_backend} for vit attention",
                        scope="local",
                    )
                    return vit_attn_backend
            except ImportError:
                # Backend not importable here; try the next one.
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        # FP8 requires compute capability >= 8.9.
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an NCCL-backed ``ProcessGroup`` without the global registry.

        ``backend`` is accepted for interface compatibility but is not read;
        NCCL is always used. ``timeout`` is applied to the NCCL backend
        options.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        # Keyed on CUDA_VISIBLE_DEVICES so a changed value is not served stale.
        return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ``ValueError`` if ``dtype`` is bfloat16 on a pre-8.0 GPU."""
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU."""
        # Blocks are indexed along dim 1 of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU)."""
        # Blocks are indexed along dim 1 of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True

    @classmethod
    def support_deep_gemm(cls) -> bool:
        """Currently, only Hopper and Blackwell GPUs are supported."""
        return cls.is_device_capability(90) or cls.is_device_capability_family(100)

    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Number of streaming multiprocessors on ``device_id``."""
        return torch.cuda.get_device_properties(device_id).multi_processor_count

    @classmethod
    def use_custom_op_collectives(cls) -> bool:
        return True

553

554
555
556
557
558
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through NVML.

    Logical device ids are mapped with ``device_id_to_physical_device_id``
    before being passed to NVML.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Compute capability of ``device_id``, or None on ``RuntimeError``.

        The result (including a None result) is memoized per device id.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Like the base implementation, but ``RuntimeError`` means False."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Returns False as soon as any pair reports a non-OK NVLink P2P status,
        or if NVML raises while querying a pair.
        """
        from itertools import combinations

        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Each unordered pair is checked exactly once.
        for handle, peer_handle in combinations(handles, 2):
            try:
                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                    handle,
                    peer_handle,
                    pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                )
            except pynvml.NVMLError:
                logger.exception(
                    "NVLink detection failed. This is normal if"
                    " your machine has no NVLink equipped."
                )
                return False
            if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Caller must already hold an NVML context (no decorator here).
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when mixed GPU models are present without PCI_BUS_ID order."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(num_devices)
            ]
            if (
                len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
            ):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform used when NVML is unavailable (e.g. on Jetson).

    Device properties are queried through ``torch.cuda`` instead of NVML.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Compute capability of ``device_id`` (memoized per device id)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always False: NVLink topology cannot be queried without NVML."""
        # logger.exception() must only be used inside an exception handler.
        # No exception is active here, so it would append a spurious
        # "NoneType: None" traceback. Log a plain warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    pass
else:
    # The probe succeeded. Record that, then release the NVML handle; it was
    # only needed to test whether initialization works.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()