"vllm/vscode:/vscode.git/clone" did not exist on "ed3aeb25a4cf833027ce937c5fdfe50371b7fabd"
cuda.py 25.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from datetime import timedelta
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
11

12
import torch
13
14
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
15
from typing_extensions import ParamSpec
16

17
18
# import custom ops, trigger op registration
import vllm._C  # noqa
19
import vllm.envs as envs
20
from vllm.logger import init_logger
21
from vllm.utils import cuda_device_count_stateless, import_pynvml
22

23
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
24

25
if TYPE_CHECKING:
26
    from vllm.config import ModelConfig, VllmConfig
27

28
29
logger = init_logger(__name__)

# Type variables used by `with_nvml_context` to preserve the signature of
# the functions it decorates.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# Module-level handle to the NVML bindings used by every NVML query below.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that it runs inside an NVML session.

    ``pynvml.nvmlInit()`` is called before ``fn`` runs. ``pynvml.nvmlShutdown()``
    is always called afterwards, including when ``fn`` raises. The wrapper
    keeps ``fn``'s signature and metadata (via ``functools.wraps``).
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            # Pair every init with a shutdown, even on error.
            pynvml.nvmlShutdown()

    return wrapper


class CudaPlatformBase(Platform):
    """CUDA platform logic shared by the NVML and non-NVML implementations.

    The device-query primitives (capability, name, total memory, NVLink
    topology) raise ``NotImplementedError`` here; subclasses provide them.
    This base class holds the CUDA-specific config adjustments and the
    attention-backend selection logic.
    """
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the model dtypes usable on the current device.

        The list depends on the device's compute capability.
        """
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        # Implemented by subclasses.
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        # Implemented by subclasses.
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        # Implemented by subclasses.
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is disabled only for eager mode on V0."""
        if enforce_eager and not envs.VLLM_USE_V1:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        # Implemented by subclasses.
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        # No platform warnings by default; subclasses may override.
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA-specific defaults on ``vllm_config`` in place.

        This method:
        - resolves ``worker_cls == "auto"`` to a concrete worker class;
        - defaults the KV-cache block size to 16;
        - forces the block size required by the MLA backend in use;
        - disables CUDA graphs for DP with DeepEP high-throughput kernels.

        Raises:
            NotImplementedError: For multi-step scheduling on V1, or for
                speculative decoding on V0.
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if not envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not supported on vLLM V0.")
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                            "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False

            if envs.VLLM_ATTENTION_BACKEND is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA_VLLM_V1"
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                use_cutlass_mla = (
                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")

            from vllm.attention.ops.flashmla import is_flashmla_supported

            # cache_config is treated as optional above. Only force a block
            # size when there is a cache config to update.
            if cache_config is not None:
                if use_flashmla and is_flashmla_supported()[0] \
                    and cache_config.block_size != 64:
                    cache_config.block_size = 64
                    logger.info(
                        "Forcing kv cache block size to 64 for FlashMLA "
                        "backend.")

                if use_cutlass_mla and cache_config.block_size != 128:
                    cache_config.block_size = 128
                    logger.info("Forcing kv cache block size to 128 for "
                                "CUTLASS_MLA_VLLM_V1 backend.")

        compilation_config = vllm_config.compilation_config
        if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                and parallel_config.data_parallel_size > 1
                and compilation_config.use_cudagraph):
            logger.info(
                "Data Parallel: Forcing enforce eager to be True since DP "
                "with DeepEP high-throughput kernels are not CUDA Graph "
                "compatible. The DeepEP low-latency kernels are CUDA Graph "
                "compatible. Set the all_to_all backend to deepep_low_latency "
                "to use those kernels instead.")
            compilation_config.use_cudagraph = False
            if model_config is not None:
                model_config.enforce_eager = True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the bytes currently allocated by PyTorch on ``device``.

        The peak statistics are reset first, so the max-allocated counter
        reflects only the current allocation.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the import path of the attention backend class to use.

        Args:
            selected_backend: Backend explicitly requested (a ``_Backend``
                member), or a falsy value to auto-select.
            head_size: Attention head size of the model.
            dtype: Model dtype.
            kv_cache_dtype: KV-cache dtype string (``"fp8..."`` variants
                are checked), or None.
            block_size: KV-cache block size.
            use_v1: Whether the V1 engine is in use.
            use_mla: Whether the model uses MLA attention.

        Raises:
            ValueError: If ``selected_backend`` is not a V0 backend that
                this platform supports.
        """
        if use_mla:
            # TODO(lucas): refactor to  be more concise
            #  we should probably consider factoring out V1 here
            if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
                if use_v1:
                    logger.info_once("Using Cutlass MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "cutlass_mla.CutlassMLABackend")
                logger.warning(
                    "Cutlass MLA backend is only supported on V1 engine")

            # FlashMLA is only a candidate for block size 64 when Triton MLA
            # was not explicitly requested.
            if selected_backend != _Backend.TRITON_MLA and block_size == 64:
                from vllm.attention.backends.flashmla import (
                    is_flashmla_supported)
                flashmla_support = is_flashmla_supported()
                if flashmla_support[0]:
                    if use_v1:
                        logger.info_once(
                            "Using FlashMLA backend on V1 engine.")
                        return ("vllm.v1.attention.backends.mla."
                                "flashmla.FlashMLABackend")
                    logger.info("Using FlashMLA backend.")
                    return ("vllm.attention.backends."
                            "flashmla.FlashMLABackend")
                # Fall back to Triton MLA below rather than leaving the MLA
                # path and selecting a non-MLA backend for an MLA model.
                logger.warning(
                    "FlashMLA backend is not supported due to %s",
                    flashmla_support[1])

            if use_v1:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"

        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501

            # Explicitly selected V1 backends.
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                if cls.has_device_capability(100):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                return FLASHINFER_V1
            elif selected_backend == _Backend.FLEX_ATTENTION:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1
            elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return TRITON_ATTN_VLLM_V1
            elif selected_backend == _Backend.FLASH_ATTN:
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return FLASH_ATTN_V1

            from vllm.attention.selector import is_attn_backend_supported

            # Default backends for V1 engine
            # Prefer FlashInfer for Blackwell GPUs if installed
            if cls.is_device_capability(100):
                if is_default_backend_supported := is_attn_backend_supported(
                        FLASHINFER_V1, head_size, dtype):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)

                    logger.info_once(
                        "Using FlashInfer backend with HND KV cache layout on "
                        "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                    set_kv_cache_layout("HND")

                    return FLASHINFER_V1

                if not is_default_backend_supported.can_import:
                    logger.warning_once(
                        "FlashInfer failed to import for V1 engine on "
                        "Blackwell (SM 10.0) GPUs; it is recommended to "
                        "install FlashInfer for better performance.")

            if cls.has_device_capability(80):
                # FlashAttention is the default for SM 8.0+ GPUs
                if is_default_backend_supported := is_attn_backend_supported(
                        FLASH_ATTN_V1, head_size, dtype,
                        allow_import_error=False):
                    logger.info_once("Using Flash Attention backend on "
                                     "V1 engine.")
                    return FLASH_ATTN_V1
            else:
                # FlexAttention is the default for older GPUs
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1

            # Only reached on SM 8.0+ when FlashAttention rejected this
            # model; report which properties caused the fallback.
            assert not is_default_backend_supported

            use_flex_attention_reason = {}
            if not is_default_backend_supported.head_size:
                use_flex_attention_reason["head_size"] = head_size
            if not is_default_backend_supported.dtype:
                use_flex_attention_reason["dtype"] = dtype

            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                ", ".join(f"{k}={v}"
                          for k, v in use_flex_attention_reason.items()),
            )
            return FLEX_ATTENTION_V1

        # Backends for V0 engine
        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            if cls.has_device_capability(100):
                from vllm.v1.attention.backends.utils import (
                    set_kv_cache_layout)
                logger.info_once(
                    "Using HND KV cache layout on V1 engine by default for "
                    "Blackwell (SM 10.0) GPUs.")
                set_kv_cache_layout("HND")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
        elif selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
            logger.info("Using DualChunkFlashAttention backend.")
            return ("vllm.attention.backends.dual_chunk_flash_attn."
                    "DualChunkFlashAttentionBackend")
        elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
            logger.info("Using DifferentialFlashAttention backend.")
            return ("vllm.attention.backends.differential_flash_attn."
                    "DifferentialFlashAttentionBackend")
        elif selected_backend == _Backend.FLASH_ATTN:
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        # Start from FlashAttention and downgrade to XFormers whenever a
        # requirement is not met.
        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend, flash_attn_supports_fp8)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS
                fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))
                if (fp8_kv_cache and not flash_attn_supports_fp8()):
                    logger.info(
                        "Cannot use FlashAttention backend for FP8 KV cache.")
                    logger.warning(
                        "Please use FlashInfer backend with FP8 KV Cache for "
                        "better performance by setting environment variable "
                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        # FP8 requires compute capability 8.9 or newer.
        return cls.has_device_capability(89)

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an NCCL-backed ``ProcessGroup`` over ``prefix_store``.

        The group is not registered with the default torch.distributed
        world. ``backend`` is accepted for interface compatibility; NCCL is
        always used here.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        """Return whether ``kv_cache_dtype`` is usable on this device.

        Blackwell (SM 10.0) always reports support. Otherwise, an FP8 KV
        cache is supported only when FlashAttention will be used and
        reports FP8 support. Any other case returns False.
        """
        fp8_attention = kv_cache_dtype.startswith("fp8")
        # FlashAttention is used when no backend is forced, or when it is
        # forced by name. The V1 selector maps `FLASH_ATTN` to
        # FlashAttention, so accept it as well as the legacy
        # `FLASH_ATTN_VLLM_V1` spelling.
        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
                       or envs.VLLM_ATTENTION_BACKEND
                       in ("FLASH_ATTN", "FLASH_ATTN_VLLM_V1"))
        if cls.is_device_capability(100):
            return True
        if fp8_attention and will_use_fa:
            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
            return flash_attn_supports_fp8()
        return False


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through NVML.

    Every public query runs inside its own NVML session via
    ``with_nvml_context``. Methods that take a logical ``device_id`` first
    map it with ``device_id_to_physical_device_id``.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id``.

        Returns None if a RuntimeError is raised during the lookup. The
        result is memoized with ``functools.cache``.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Compare ``capability`` against the device's capability.

        A RuntimeError during the lookup is treated as "not supported".
        """
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of the given physical devices must report
        NVLink P2P status OK. An NVML error on any pair is logged and
        treated as "not connected".
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        for idx, handle in enumerate(handles):
            # Each pair is checked once: only peers after `handle`.
            for peer_handle in handles[idx + 1:]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Caller must already hold an NVML session.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when heterogeneous GPUs are present without PCI_BUS_ID order."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(num_devices)
            ]
            if (len(set(device_names)) > 1
                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``.

    This platform is selected when NVML cannot be initialized.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always report "not fully connected"; NVLink cannot be probed here.

        ``physical_device_ids`` is unused.
        """
        # `logger.exception` is only meaningful inside an exception handler.
        # There is no active exception here, so a plain warning is logged
        # instead of a bogus "NoneType: None" traceback.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available. A successful probe session is closed
# right away; each NVML query opens its own session later.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()