cuda.py 29.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from datetime import timedelta
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
11

12
import torch
13
14
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
15
from typing_extensions import ParamSpec
16

17
18
# import custom ops, trigger op registration
import vllm._C  # noqa
19
import vllm.envs as envs
20
from vllm.logger import init_logger
21
from vllm.utils import cuda_device_count_stateless, import_pynvml
22

23
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
24

25
if TYPE_CHECKING:
26
    from vllm.config import ModelConfig, VllmConfig
27

28
29
logger = init_logger(__name__)

# Signature-preserving type variables for the `with_nvml_context` decorator.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# Module-level NVML binding used by every NVML helper in this file.
pynvml = import_pynvml()

# PyTorch 2.5 enables the cuDNN SDPA kernel by default, which crashes on
# some models; see https://github.com/huggingface/diffusers/issues/9704.
torch.backends.cuda.enable_cudnn_sdp(False)

39

40
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap ``fn`` so that each call is bracketed by NVML init/shutdown.

    ``pynvml.nvmlInit()`` runs before ``fn`` is invoked, and
    ``pynvml.nvmlShutdown()`` always runs afterwards, even when ``fn``
    raises. The wrapper keeps ``fn``'s metadata via ``functools.wraps``.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Runs on both the success and the exception path.
            pynvml.nvmlShutdown()
        return result

    return _nvml_scoped


53
54
class CudaPlatformBase(Platform):
    """Behaviour shared by all CUDA platform variants.

    Device-query hooks (capability, name, total memory, NVLink topology)
    raise ``NotImplementedError`` here and are expected to be provided by
    subclasses.
    """
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the floating-point dtypes usable on the current device.

        - Compute capability >= 8.0 (Ampere, Hopper or later):
          bfloat16, float16, float32.
        - 6.0 <= capability < 8.0 (Pascal, Volta, Turing): float16, float32
          (BF16 is not supported).
        - Older (Kepler, Maxwell): float32 only, though vLLM does not
          support these GPUs.
        """
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # The >= 8.0 case already returned above, so this is exactly
            # Pascal, Volta and Turing; BF16 is not supported.
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.

        After ``torch.cuda.set_device`` a tiny tensor is allocated on the
        device to force the device to be set eagerly; see
        https://github.com/pytorch/pytorch/issues/155668 for why and when
        this is needed.
        """
        torch.cuda.set_device(device)
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id``; subclass hook."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the marketing name of ``device_id``; subclass hook."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id``; subclass hook."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return whether async output processing can be used.

        It is disabled only when eager mode is enforced on the V0 engine.
        """
        if enforce_eager and not envs.VLLM_USE_V1:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether the given GPUs are NVLink-connected; subclass hook."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform-specific startup warnings (none by default)."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply CUDA-specific defaults to ``vllm_config`` in place.

        - Resolves ``parallel_config.worker_cls == "auto"`` to the V1 or V0
          GPU worker class.
        - Sets the KV-cache block size to 16 when it is unset.
        - For MLA models, forces the block size required by the MLA backend
          in use (64 for FlashMLA, 128 for CUTLASS MLA).
        - With the DeepEP high-throughput all2all backend and data
          parallelism, downgrades CUDA graph capture to PIECEWISE.

        Raises:
            NotImplementedError: speculative decoding is requested while the
                V0 engine is active.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            if vllm_config.speculative_config:
                if not envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not supported on vLLM V0.")
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            elif envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.gpu_worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False

            if envs.VLLM_ATTENTION_BACKEND is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    # TODO: This does not work, because the
                    # global_force_attn_backend_context_manager is not set.
                    # See vllm/attention/selector.py:_cached_get_attn_backend
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                use_cutlass_mla = (
                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")

            from vllm.attention.ops.flashmla import is_flashmla_supported

            # cache_config is treated as optional above; without it there
            # is no block size to force (previously this raised
            # AttributeError on None).
            if cache_config is not None:
                if (use_flashmla and is_flashmla_supported()[0]
                        and cache_config.block_size != 64):
                    cache_config.block_size = 64
                    logger.info(
                        "Forcing kv cache block size to 64 for FlashMLA "
                        "backend.")

                if use_cutlass_mla and cache_config.block_size != 128:
                    cache_config.block_size = 128
                    logger.info("Forcing kv cache block size to 128 for "
                                "CUTLASS_MLA backend.")

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                and parallel_config.data_parallel_size > 1
                and compilation_config.cudagraph_mode
                not in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE]):
            logger.info(
                "Data Parallel with DeepEP high-throughput: using PIECEWISE "
                "CUDA graphs and excluding MoE ops from capture. Set "
                "VLLM_ALL2ALL_BACKEND=deepep_low_latency if you need MoE "
                "graphs captured as well.")
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the memory currently allocated by tensors on ``device``.

        The CUDA caching allocator is emptied and the peak statistics are
        reset first. As a result, ``max_memory_allocated`` reports the
        present allocation rather than a historical peak.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
        """Choose the attention backend for vision (ViT) encoders.

        FlashAttention is used when the device is SM 8.0+, the caller allows
        it (``support_fa``), and the upstream ``flash-attn`` package is
        available. Otherwise xFormers is used.
        """
        if cls.has_device_capability(80) and support_fa:
            from transformers.utils import is_flash_attn_2_available
            if is_flash_attn_2_available():
                return _Backend.FLASH_ATTN
            logger.warning_once(
                "Current `vllm-flash-attn` has a bug inside vision "
                "module, so we use xformers backend instead. You can "
                "run `pip install flash-attn` to use flash-attention "
                "backend.")
        # Fallback for Volta/Turing GPUs or FA not supported
        return _Backend.XFORMERS

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
        """Return the dotted import path of the attention backend class.

        Args:
            selected_backend: ``_Backend`` member requested explicitly, or
                ``None`` to auto-select.
            head_size: Attention head size.
            dtype: Model dtype.
            kv_cache_dtype: KV-cache dtype string (values starting with
                ``"fp8"`` select FP8 handling), or ``None``.
            block_size: KV-cache block size.
            use_v1: Whether the V1 engine is in use.
            use_mla: Whether an MLA backend should be selected.
            has_sink: Caller-supplied flag (presumably attention sinks;
                confirm with caller). On SM 8.x other than SM 9.0, it routes
                default selection to the Triton backend.

        Raises:
            ValueError: ``selected_backend`` is not valid for the V0 engine
                on CUDA.
        """
        if use_mla:
            # TODO(lucas): refactor to be more concise
            #  we should probably consider factoring out V1 here

            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size == 128)
            use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
                                 and cls.has_device_capability(100))
            use_flashmla = selected_backend in [
                _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
            ] or (selected_backend is None and is_flashmla_supported()[0])
            use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                selected_backend is None and flash_attn_supports_mla())
            use_triton = selected_backend == _Backend.TRITON_MLA or (
                selected_backend is None)

            def _get_version(name, import_suffix) -> str:
                # Map an MLA backend to its V1 or V0 module path.
                if use_v1:
                    logger.info_once(f"Using {name} backend on V1 engine.")
                    return f"vllm.v1.attention.backends.mla.{import_suffix}"
                else:
                    logger.info_once(f"Using {name} backend.")
                    return f"vllm.attention.backends.{import_suffix}"

            # Candidates are tried in priority order. A candidate that cannot
            # be used only logs a warning and the next one is tried.
            if use_cutlassmla:
                if use_v1:
                    logger.info_once("Using Cutlass MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "cutlass_mla.CutlassMLABackend")
                else:
                    logger.warning(
                        "Cutlass MLA backend is only supported on V1 engine")
            if use_flashinfermla:
                if use_v1:
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                    logger.info_once(
                        "Using FlashInfer MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashinfer_mla.FlashInferMLABackend")
                else:
                    logger.warning(
                        "FlashInfer MLA backend is only supported on V1 engine"
                    )
            if use_flashmla:
                if block_size != 64:
                    logger.warning(
                        "FlashMLA backend is not supported for block size %d"
                        " (currently only supports block size 64).",
                        block_size)
                else:
                    return _get_version("FlashMLA", "flashmla.FlashMLABackend")
            if use_flashattn:
                if use_v1:
                    logger.info_once(
                        "Using FlashAttention MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashattn_mla.FlashAttnMLABackend")
                else:
                    logger.warning(
                        "FlashAttention MLA backend is only supported on V1 "
                        "engine.")
            if use_triton:
                return _get_version("Triton MLA",
                                    "triton_mla.TritonMLABackend")
            # NOTE(review): if no MLA candidate returned, selection continues
            # with the non-MLA logic below. Confirm this fallthrough is
            # intended.

        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

            # Explicitly requested V1 backends.
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                if cls.has_device_capability(100):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                return FLASHINFER_V1
            elif selected_backend == _Backend.FLEX_ATTENTION:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1
            elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return TRITON_ATTN_VLLM_V1
            elif selected_backend == _Backend.FLASH_ATTN:
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return FLASH_ATTN_V1
            elif selected_backend == _Backend.TREE_ATTN:
                logger.info_once("Using Tree Attention backend on V1 engine.")
                return TREE_ATTN_V1
            elif selected_backend == _Backend.XFORMERS_VLLM_V1:
                logger.info_once("Using XFormers backend on V1 engine.")
                return XFORMERS_V1

            from vllm.attention.selector import is_attn_backend_supported

            # Default backends for V1 engine
            # Prefer FlashInfer for Blackwell GPUs if installed
            if cls.is_device_capability(100):
                if is_default_backend_supported := is_attn_backend_supported(
                        FLASHINFER_V1, head_size, dtype):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)

                    logger.info_once(
                        "Using FlashInfer backend with HND KV cache layout on "
                        "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                    set_kv_cache_layout("HND")

                    return FLASHINFER_V1

                if not is_default_backend_supported.can_import:
                    logger.warning_once(
                        "FlashInfer failed to import for V1 engine on "
                        "Blackwell (SM 10.0) GPUs; it is recommended to "
                        "install FlashInfer for better performance.")

            # FlashAttention is the default for SM 8.0+ GPUs
            if cls.has_device_capability(80):
                if has_sink and not cls.is_device_capability(90):
                    logger.info_once("Using Triton backend on V1 engine.")
                    return TRITON_ATTN_VLLM_V1
                if is_default_backend_supported := is_attn_backend_supported(
                        FLASH_ATTN_V1, head_size, dtype,
                        allow_import_error=False):
                    logger.info_once("Using Flash Attention backend on "
                                     "V1 engine.")
                    return FLASH_ATTN_V1

            # FlexAttention is the default for older GPUs
            else:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1

            # Reaching here means FlashAttention was rejected; its support
            # report says why, and FlexAttention is used instead.
            assert not is_default_backend_supported

            use_flex_attention_reason = {}
            if not is_default_backend_supported.head_size:
                use_flex_attention_reason["head_size"] = head_size
            if not is_default_backend_supported.dtype:
                use_flex_attention_reason["dtype"] = dtype

            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                ", ".join(f"{k}={v}"
                          for k, v in use_flex_attention_reason.items()),
            )
            return FLEX_ATTENTION_V1

        # Backends for V0 engine
        if selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
            logger.info("Using DualChunkFlashAttention backend.")
            return ("vllm.attention.backends.dual_chunk_flash_attn."
                    "DualChunkFlashAttentionBackend")
        elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
            logger.info("Using DifferentialFlashAttention backend.")
            return ("vllm.attention.backends.differential_flash_attn."
                    "DifferentialFlashAttentionBackend")
        elif selected_backend == _Backend.FLASH_ATTN:
            # Validated by the generic FlashAttention checks below.
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend, flash_attn_supports_fp8)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS
                fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))
                if (fp8_kv_cache and not flash_attn_supports_fp8()):
                    logger.info(
                        "Cannot use FlashAttention backend for FP8 KV cache.")
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the GPU Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the CUDA device communicator."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether FP8 is supported (compute capability >= 8.9)."""
        return cls.has_device_capability(89)

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        """The V1 engine is supported for every model on CUDA."""
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Custom all-reduce kernels are enabled on CUDA."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Attention is exposed as an opaque op on CUDA."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the CUDA graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an NCCL-backed ``ProcessGroup`` without global torch state.

        The group is built directly from ``prefix_store``, so it does not
        touch the default process group. ``backend`` is accepted for
        interface compatibility; NCCL is always used here.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the number of visible CUDA devices."""
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Return whether ``kv_cache_dtype`` works with the attention backend.

        The backend is read from ``VLLM_ATTENTION_BACKEND``. When that is
        unset, the default is CUTLASS_MLA on SM 10.0 or FLASHMLA elsewhere
        for MLA models, and FLASH_ATTN_VLLM_V1 for other models. Only FP8
        cache dtypes (strings starting with ``"fp8"``) can be rejected; any
        other dtype is reported as supported.
        """
        fp8_attention = kv_cache_dtype.startswith("fp8")
        attention_backend = envs.VLLM_ATTENTION_BACKEND

        supported = False
        if model_config is not None and model_config.use_mla:
            # Default to CutlassMLA for blackwell,
            # FlashMLA otherwise
            if attention_backend is None:
                if cls.is_device_capability(100):
                    attention_backend = "CUTLASS_MLA"
                else:
                    attention_backend = "FLASHMLA"

            # Only FlashMLA and CUTLASS_MLA support fp8
            if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]:
                supported = True
            else:
                supported = (not fp8_attention)
        else:
            # Default to FlashAttention
            if attention_backend is None:
                attention_backend = "FLASH_ATTN_VLLM_V1"

            # All Blackwell backends support fp8
            if cls.is_device_capability(100):
                supported = True
            elif attention_backend == "FLASH_ATTN_VLLM_V1":
                if fp8_attention:
                    from vllm.attention.utils.fa_utils import (
                        flash_attn_supports_fp8)
                    supported = flash_attn_supports_fp8()
                else:
                    supported = True
            elif attention_backend == "FLASHINFER":
                supported = True
            elif attention_backend == "TRITON_ATTN_VLLM_V1":
                # Only an fp8 cache needs hardware FP8 support.
                supported = (not fp8_attention) or cls.supports_fp8()
            else:
                # Other backends are not known to handle an fp8 cache; a
                # non-fp8 cache is always fine (mirrors the MLA branch).
                supported = not fp8_attention

        return supported

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise if ``torch_dtype`` cannot run on the current GPU.

        Raises:
            ValueError: ``torch_dtype`` is bfloat16 and the device's compute
                capability is below 8.0 (or unknown).
        """
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV-cache layouts are supported on CUDA."""
        return True

578

579
580
581
582
583
# NVML utils
# NVML ignores `CUDA_VISIBLE_DEVICES`: every NVML query below operates on
# physical device ids, so logical ids are translated before each call.
# The main benefit of going through NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered via NVML."""

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of logical ``device_id``.

        Returns ``None`` when the id cannot be resolved (``RuntimeError``).
        Results are memoized per ``(cls, device_id)``.
        """
        try:
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
                cls.device_id_to_physical_device_id(device_id))
            cc_major, cc_minor = pynvml.nvmlDeviceGetCudaComputeCapability(
                nvml_handle)
            return DeviceCapability(major=cc_major, minor=cc_minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Like the base check, but reports ``False`` on ``RuntimeError``."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for logical ``device_id``."""
        return cls._get_physical_device_name(
            cls.device_id_to_physical_device_id(device_id))

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID for logical ``device_id``."""
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            cls.device_id_to_physical_device_id(device_id))
        return pynvml.nvmlDeviceGetUUID(nvml_handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory NVML reports for logical ``device_id``."""
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            cls.device_id_to_physical_device_id(device_id))
        return int(pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of the given physical GPUs must report NVLink
        P2P status OK. Any NVML error is logged and treated as "not
        connected".
        """
        gpu_handles = [
            pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
            for gpu_index in physical_device_ids
        ]
        for pos, src_handle in enumerate(gpu_handles):
            # Only later handles: each unordered pair is checked once.
            for dst_handle in gpu_handles[pos + 1:]:
                try:
                    link_status = pynvml.nvmlDeviceGetP2PStatus(
                        src_handle,
                        dst_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if link_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Caller must already hold an NVML context.
        return pynvml.nvmlDeviceGetName(
            pynvml.nvmlDeviceGetHandleByIndex(device_id))

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when heterogeneous GPUs are present without PCI bus order."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices <= 1:
            return
        names = [
            cls._get_physical_device_name(gpu_index)
            for gpu_index in range(num_devices)
        ]
        is_heterogeneous = len(set(names)) > 1
        if (is_heterogeneous
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform for systems without NVML (e.g. Jetson).

    Device queries go through ``torch.cuda`` instead of NVML.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of ``device_id`` (memoized)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device's torch properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """NVLink topology cannot be queried without NVML; always ``False``."""
        # Not inside an exception handler, so `logger.exception` would attach
        # a bogus "NoneType: None" traceback; this is a plain fallback
        # warning.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Pick the NVML-backed platform when NVML can be initialized, otherwise fall
# back to the torch.cuda-backed one (NVML is unavailable on e.g. Jetson).
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # Probe only; release the NVML session opened above.
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()