"vllm/vscode:/vscode.git/clone" did not exist on "ff38f0a32c5f7a2dd1aa3bee3806d957adacb9bd"
cuda.py 25 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from collections.abc import Callable
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, TypeVar
11

12
import torch
13
from typing_extensions import ParamSpec
14

15
16
# import custom ops, trigger op registration
import vllm._C  # noqa
17
from vllm.logger import init_logger
18
from vllm.utils.import_utils import import_pynvml
19
from vllm.utils.torch_utils import cuda_device_count_stateless
20
from vllm.v1.attention.backends.registry import AttentionBackendEnum
21

22
from .interface import DeviceCapability, Platform, PlatformEnum
23

24
if TYPE_CHECKING:
25
    from vllm.config import VllmConfig
26
    from vllm.config.cache import CacheDType
27
    from vllm.v1.attention.selector import AttentionSelectorConfig
28
else:
29
30
    VllmConfig = None
    CacheDType = None
31

32
33
logger = init_logger(__name__)

# Module-wide NVML binding used by the NVML-backed helpers in this file.
pynvml = import_pynvml()

# Type variables that let decorators preserve the wrapped callable's
# parameter and return types.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

43

44
45
46
47
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
    num_heads: int | None = None,
) -> list[AttentionBackendEnum]:
    """Return the attention backends to try, most preferred first.

    Args:
        use_mla: Whether the model uses MLA attention.
        device_capability: Compute capability of the target GPU; only the
            major version is consulted (10 selects the Blackwell ordering).
        num_heads: Number of attention heads, if known. On Blackwell MLA it
            decides which sparse MLA backend is preferred.

    Results are memoized per argument tuple by ``functools.cache``, so
    identical calls return the same list object; callers should treat it
    as read-only.
    """
    is_blackwell = device_capability.major == 10

    if not use_mla:
        leading = (
            [AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASH_ATTN]
            if is_blackwell
            else [AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASHINFER]
        )
        return [
            *leading,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]

    if not is_blackwell:
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    sparse_order = [
        AttentionBackendEnum.FLASHMLA_SPARSE,
        AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
    ]
    # Prefer FlashInfer at low head counts (FlashMLA uses padding)
    if num_heads is not None and num_heads <= 16:
        sparse_order.reverse()

    return [
        AttentionBackendEnum.FLASHINFER_MLA,
        AttentionBackendEnum.CUTLASS_MLA,
        AttentionBackendEnum.FLASH_ATTN_MLA,
        AttentionBackendEnum.FLASHMLA,
        AttentionBackendEnum.TRITON_MLA,
        *sparse_order,
    ]


97
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap ``fn`` so each call runs inside an NVML init/shutdown pair.

    ``pynvml.nvmlInit()`` is called before ``fn`` and
    ``pynvml.nvmlShutdown()`` after it, including when ``fn`` raises.
    The wrapper keeps ``fn``'s metadata via ``functools.wraps``.
    """

    @wraps(fn)
    def _nvml_scoped_call(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            # Always release NVML, even on error.
            pynvml.nvmlShutdown()

    return _nvml_scoped_call


109
110
class CudaPlatformBase(Platform):
    """CUDA platform behaviour shared by the NVML and non-NVML variants.

    The device-query hooks (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory``, ``is_fully_connected``) raise
    ``NotImplementedError`` here and must be provided by a subclass.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    ]

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the dtypes usable on this device.

        BF16 needs compute capability >= 8.0 and FP16 needs >= 6.0; FP32 is
        always included.
        """
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Set the device for the current platform."""
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Hook for platform-specific startup warnings; no-op by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply CUDA-specific defaults to ``vllm_config`` in place.

        - Selects the GPU worker class when ``worker_cls`` is ``"auto"``.
        - Defaults the KV-cache block size to 16 when it is unset.
        - For MLA models, forces the block size the chosen (or default)
          MLA attention backend requires.
        - Forces ``disable_chunked_mm_input`` for multimodal prefix-LM models.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        # Note: block_size is initialized in
        # HybridAttentionMambaModelConfig.verify_and_update_config
        # for models with both attention and mamba,
        # and doesn't need to be reinitialized here
        # cache_config is guarded the same way as above so a missing cache
        # config cannot raise AttributeError here.
        if (
            model_config is not None
            and model_config.use_mla
            and cache_config is not None
            and cache_config.block_size is not None
        ):
            cls._force_mla_kv_block_size(vllm_config)

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def _force_mla_kv_block_size(cls, vllm_config: "VllmConfig") -> None:
        """Adjust the KV-cache block size to what the MLA backend requires.

        When ``attention_config.backend`` is unset, a default is chosen:

        - On capability family 100 with a non-sparse model, FlashInfer MLA
          is used when ``qk_nope_head_dim == 128`` and is written back to
          ``attention_config.backend``. Otherwise CUTLASS MLA is used.
        - Elsewhere, FlashMLA is used when its dense kernel is supported.

        Sparse models (``hf_config`` has ``index_topk``) use FlashMLA-sparse
        unless FlashInfer-MLA-sparse was explicitly requested.
        """
        from vllm.v1.attention.backends.registry import AttentionBackendEnum
        from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported

        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        attention_config = vllm_config.attention_config

        use_sparse = hasattr(model_config.hf_config, "index_topk")
        use_flashmla = False
        use_cutlass_mla = False
        use_flashinfer_mla = False
        use_flashmla_sparse = False
        use_flashinfer_mla_sparse = False

        if attention_config.backend is None:
            # Default case
            hf_text_config = model_config.hf_text_config
            qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
            is_blackwell = cls.is_device_capability_family(100)
            if is_blackwell and not use_sparse and qk_nope_head_dim == 128:
                # Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
                # and only if qk_nope_head_dim == 128 (kernel constraint)
                use_flashinfer_mla = True
                # Set the backend in AttentionConfig so it's used during
                # backend selection
                attention_config.backend = AttentionBackendEnum.FLASHINFER_MLA
            elif is_blackwell and not use_sparse:
                # Fall back to CUTLASS_MLA as 2nd priority on Blackwell
                use_cutlass_mla = True
            elif is_flashmla_dense_supported()[0]:
                # Non-Blackwell with FlashMLA support
                use_flashmla = True
            # Otherwise: Triton MLA or another compatible backend is used and
            # no dense block-size constraint is applied here.
        else:
            # Forced case
            backend = attention_config.backend
            use_flashmla = backend == AttentionBackendEnum.FLASHMLA
            use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
            use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
            use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
            use_flashinfer_mla_sparse = (
                backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
            )

        if (
            use_flashmla
            and is_flashmla_dense_supported()[0]
            and cache_config.block_size % 64 != 0
        ):
            cache_config.block_size = 64
            logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

        if use_cutlass_mla and cache_config.block_size % 128 != 0:
            cache_config.block_size = 128
            logger.info("Forcing kv cache block size to 128 for CUTLASS_MLA backend.")

        if (
            use_flashinfer_mla
            and cache_config.block_size != 32
            and cache_config.block_size % 64 != 0
        ):
            cache_config.block_size = 64
            logger.info("Forcing kv cache block size to 64 for FlashInferMLA backend.")

        if use_sparse:
            if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
                use_flashmla_sparse = True

            if use_flashmla_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse backend."
                )
            elif use_flashinfer_mla_sparse and cache_config.block_size not in (
                32,
                64,
            ):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLASparse "
                    "backend."
                )

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the bytes currently allocated by the torch caching allocator.

        The peak counter is reset first, so ``max_memory_allocated`` reports
        the current allocation rather than a historical peak.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> tuple[
        list[tuple["AttentionBackendEnum", int]],
        dict["AttentionBackendEnum", list[str]],
    ]:
        """Validate every candidate attention backend for this config.

        Returns:
            ``(valid, invalid)`` where ``valid`` lists ``(backend, priority)``
            for backends that passed validation (``priority`` is the index
            in the priority order; lower is preferred), and ``invalid`` maps
            each rejected backend to its reasons. A backend whose class
            cannot be imported is rejected with ``["ImportError"]``.
        """
        valid_backends_priorities = []
        invalid_reasons = {}

        backend_priorities = _get_backend_priorities(
            attn_selector_config.use_mla,
            device_capability,
            num_heads,
        )
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = invalid_reasons_i
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use.

        If ``selected_backend`` is given, it is validated and used, or a
        ``ValueError`` is raised when it is invalid. Otherwise the valid
        backend with the best priority is chosen, and ``ValueError`` is
        raised when no backend is valid.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        attn_selector_config = attn_selector_config._replace(block_size=None)
        # First try checking just the selected backend, if there is one.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was selected, so search for a valid one.
        valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, reasons in invalid_reasons.items()
            )
            + "}"
        )
        config_str = repr(attn_selector_config)
        logger.debug_once(
            f"Some attention backends are not valid for {cls.device_name} with "
            f"{config_str}. Reasons: {reasons_str}."
        )
        if not valid_backends_priorities:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # Pick the lowest priority index. min() returns the first minimal
        # entry, matching the previous stable-sort selection.
        selected_backend, _ = min(valid_backends_priorities, key=lambda bp: bp[1])
        logger.info_once(
            "Using %s attention backend out of potential backends: %s.",
            selected_backend.name,
            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
            scope="local",
        )

        return selected_backend.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return ViT attention backends in preference order."""
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.FLASHINFER,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for vision-transformer attention.

        An explicit ``backend`` must be one of the supported ViT backends.
        Otherwise the first supported backend (skipping TORCH_SDPA) that
        accepts ``head_size``, ``dtype`` and the device capability is used,
        with TORCH_SDPA as the final fallback.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        cc = cls.get_device_capability()
        for vit_attn_backend in cls.get_supported_vit_attn_backends():
            if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
                continue
            try:
                backend_class = vit_attn_backend.get_class()
                is_backend_supported = backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype)
                if cc is not None:
                    is_backend_supported = (
                        is_backend_supported
                        and backend_class.supports_compute_capability(cc)
                    )
                if is_backend_supported:
                    logger.info_once(
                        f"Using backend {vit_attn_backend} for vit attention"
                    )
                    return vit_attn_backend
            except ImportError:
                # Backend not installed; try the next one.
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 requires compute capability >= 8.9."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ``ValueError`` if ``dtype`` is bfloat16 on a pre-8.0 GPU."""
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU)."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True

    @classmethod
    def num_compute_units(cls, device_id=0):
        """Return the streaming-multiprocessor count of ``device_id``."""
        return torch.cuda.get_device_properties(device_id).multi_processor_count

546

547
548
549
550
551
# NVML utils
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    NVML is not affected by ``CUDA_VISIBLE_DEVICES`` and works on physical
    device ids, so logical ids are translated with
    ``device_id_to_physical_device_id`` before querying. The main benefit
    of NVML is that it does not initialize CUDA.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability, or ``None`` on ``RuntimeError``."""
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(
                cls.device_id_to_physical_device_id(device_id)
            )
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Base-class capability check, reporting ``False`` on ``RuntimeError``."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        return cls._get_physical_device_name(
            cls.device_id_to_physical_device_id(device_id)
        )

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        handle = pynvml.nvmlDeviceGetHandleByIndex(
            cls.device_id_to_physical_device_id(device_id)
        )
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        handle = pynvml.nvmlDeviceGetHandleByIndex(
            cls.device_id_to_physical_device_id(device_id)
        )
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Visit each unordered pair once, in (i, j) order with i < j.
        for first, handle in enumerate(handles):
            for peer_handle in handles[first + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Caller must already hold an NVML context.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn about mixed GPU models when CUDA_DEVICE_ORDER is not PCI_BUS_ID."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices <= 1:
            return
        device_names = [cls._get_physical_device_name(i) for i in range(num_devices)]
        mixed_models = len(set(device_names)) > 1
        pci_ordered = os.environ.get("CUDA_DEVICE_ORDER") == "PCI_BUS_ID"
        if mixed_models and not pci_ordered:
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that queries devices through ``torch.cuda``.

    Used when NVML cannot be initialized (e.g. on Jetson).
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always report no NVLink, since it cannot be detected without NVML."""
        # logger.warning, not logger.exception: no exception is being handled
        # here, so exception() would append a spurious "NoneType: None"
        # traceback to the log record.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


def _probe_nvml() -> bool:
    """Return whether NVML can be initialized here.

    A successful probe calls ``nvmlShutdown`` right away so no NVML session
    stays open after import.
    """
    try:
        pynvml.nvmlInit()
    except Exception:
        # On Jetson, NVML is not supported.
        return False
    pynvml.nvmlShutdown()
    return True


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = _probe_nvml()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()