# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

import os
from datetime import timedelta
from functools import cache, wraps
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec

# import custom ops, trigger op registration
import vllm._C  # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless, import_pynvml

from .interface import DeviceCapability, Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
else:
    # At runtime `_Backend` is imported lazily inside the methods that need
    # it; this placeholder only keeps the module-level name defined.
    _Backend = None

logger = init_logger(__name__)

# Type variables used by `with_nvml_context` so the decorated function keeps
# the exact parameter and return types of the wrapped one.
_P = ParamSpec("_P")
_R = TypeVar("_R")

pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that it runs inside an NVML session.

    ``pynvml.nvmlInit()`` is called before ``fn`` runs, and
    ``pynvml.nvmlShutdown()`` is called afterwards from a ``finally`` clause,
    so the shutdown also happens when ``fn`` raises. If ``nvmlInit`` itself
    raises, ``fn`` is not called and no shutdown is attempted.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same signature as ``fn`` (``functools.wraps``
        copies its metadata).
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


class CudaPlatformBase(Platform):
    """Platform logic shared by all CUDA platform variants.

    Device queries (capability, name, total memory, NVLink connectivity) are
    left abstract here (they raise ``NotImplementedError``) and are provided
    by subclasses.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the model dtypes usable on the current GPU generation."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Hook for platform-specific startup warnings; no-op by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for CUDA.

        - Resolves an ``"auto"`` worker class to the V1 GPU worker.
        - Defaults an unset KV-cache block size to 16.
        - For MLA models, forces the block size required by the chosen MLA
          backend (and, on SM 10.0 with no backend configured, sets
          ``envs.VLLM_ATTENTION_BACKEND`` to ``"CUTLASS_MLA"``).
        - Disables CUDA graphs when the DeepEP high-throughput all2all
          backend is used with data parallelism.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            use_sparse = hasattr(model_config.hf_config, "index_topk")
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False
            use_flashinfer_mla = False

            if envs.VLLM_ATTENTION_BACKEND is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    # TODO: This does not work, because the
                    # global_force_attn_backend_context_manager is not set.
                    # See vllm/attention/selector.py:_cached_get_attn_backend
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                use_cutlass_mla = (
                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
                use_flashinfer_mla = (
                    envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")

            from vllm.attention.ops.flashmla import is_flashmla_supported

            # The overrides below edit the cache config, which this method
            # already treats as optional (see the default block size above),
            # so skip them instead of dereferencing a missing config.
            if cache_config is not None:
                if (use_flashmla and is_flashmla_supported()[0]
                        and cache_config.block_size != 64):
                    cache_config.block_size = 64
                    logger.info("Forcing kv cache block size to 64 for "
                                "FlashMLA backend.")

                if use_cutlass_mla and cache_config.block_size != 128:
                    cache_config.block_size = 128
                    logger.info("Forcing kv cache block size to 128 for "
                                "CUTLASS_MLA backend.")

                if (use_flashinfer_mla
                        and cache_config.block_size not in [32, 64]):
                    cache_config.block_size = 64
                    logger.info(
                        "Forcing kv cache block size to 64 for FlashInferMLA "
                        "backend.")

                # TODO(Chen): remove this hacky code
                if use_sparse and cache_config.block_size != 64:
                    cache_config.block_size = 64
                    logger.info(
                        "Forcing kv cache block size to 64 for FlashMLASparse "
                        "backend.")

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                and parallel_config.data_parallel_size > 1
                and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
            # TODO: Piecewise Cuda graph might be enabled
            # if torch compile cache key issue fixed
            # See https://github.com/vllm-project/vllm/pull/25093
            logger.info(
                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
                "kernels are optimized for prefill and are incompatible with "
                "CUDA Graphs. "
                "In order to use CUDA Graphs for decode-optimized workloads, "
                "set VLLM_ALL2ALL_BACKEND to another option, such as "
                "deepep_low_latency, pplx, or allgather_reducescatter.")
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the memory currently allocated by tensors on ``device``.

        The allocator cache is emptied and the peak statistics are reset
        before ``max_memory_allocated`` is read, so the reported peak
        reflects allocations at call time rather than a historical maximum.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> "_Backend":
        """Choose the attention backend for vision-encoder (ViT) layers.

        Uses TORCH_SDPA on SM 10.0+. Otherwise uses FLASH_ATTN on SM 8.0+
        when the dtype is fp16/bf16 and FlashAttention supports
        ``head_size``/``dtype``. XFORMERS is the fallback in all other cases.
        """
        from vllm.attention.backends.registry import _Backend

        # For Blackwell GPUs, force TORCH_SDPA for now.
        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
        if cls.has_device_capability(100):
            return _Backend.TORCH_SDPA

        if dtype not in (torch.float16, torch.bfloat16):
            return _Backend.XFORMERS

        if cls.has_device_capability(80):
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            from vllm.attention.selector import is_attn_backend_supported
            is_default_fa_supported = is_attn_backend_supported(
                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False)
            if is_default_fa_supported:
                return _Backend.FLASH_ATTN
            # Fallback to XFORMERS
            return _Backend.XFORMERS

        # Fallback for Volta/Turing GPUs or FA not supported
        return _Backend.XFORMERS

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink, use_sparse) -> str:
        """Return the dotted import path of the attention backend class.

        MLA models pick among the MLA backends (sparse, CUTLASS, FlashInfer,
        FlashMLA, FlashAttention, Triton) in that order of preference. If no
        MLA backend matches ``selected_backend``, selection falls through to
        the regular V1 backends. Regular models honour an explicit
        ``selected_backend``, otherwise default to FlashInfer (SM 10.0),
        FlashAttention (SM 8.0+), or FlexAttention.

        Raises:
            RuntimeError: if ``use_v1`` is false (MLA or not).
        """
        from vllm.attention.backends.registry import _Backend
        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them.")

            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla.flashmla_sparse."
                        "FlashMLASparseBackend")

            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size == 128)
            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size in [32, 64])
            use_flashmla = selected_backend == _Backend.FLASHMLA or (
                selected_backend is None and is_flashmla_supported()[0])
            use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                selected_backend is None and flash_attn_supports_mla())
            use_triton = selected_backend == _Backend.TRITON_MLA or (
                selected_backend is None)

            if use_cutlassmla:
                logger.info_once("Using Cutlass MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "cutlass_mla.CutlassMLABackend")
            if use_flashinfermla:
                from vllm.v1.attention.backends.utils import (
                    set_kv_cache_layout)
                set_kv_cache_layout("HND")
                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashinfer_mla.FlashInferMLABackend")
            if use_flashmla:
                if block_size != 64:
                    # Fall through to the next candidate backend.
                    logger.warning(
                        "FlashMLA backend is not supported for block size %d"
                        " (currently only supports block size 64).",
                        block_size)
                else:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashmla.FlashMLABackend")
            if use_flashattn:
                logger.info_once(
                    "Using FlashAttention MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashattn_mla.FlashAttnMLABackend")
            if use_triton:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")
        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

            use_fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))

            # An explicitly selected backend always wins.
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                if cls.has_device_capability(100):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                return FLASHINFER_V1
            elif selected_backend == _Backend.FLEX_ATTENTION:
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1
            elif selected_backend == _Backend.TRITON_ATTN:
                logger.info_once("Using Triton backend on V1 engine.")
                return TRITON_ATTN
            elif selected_backend == _Backend.FLASH_ATTN:
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return FLASH_ATTN_V1
            elif selected_backend == _Backend.TREE_ATTN:
                logger.info_once("Using Tree Attention backend on V1 engine.")
                return TREE_ATTN_V1
            elif selected_backend == _Backend.XFORMERS:
                logger.info_once("Using XFormers backend on V1 engine.")
                return XFORMERS_V1

            from vllm.attention.selector import is_attn_backend_supported

            # Default backends for V1 engine
            # Prefer FlashInfer for Blackwell GPUs if installed
            if cls.is_device_capability(100):
                if is_default_backend_supported := is_attn_backend_supported(
                        FLASHINFER_V1, head_size, dtype):
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)

                    logger.info_once(
                        "Using FlashInfer backend with HND KV cache layout on "
                        "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                    set_kv_cache_layout("HND")

                    return FLASHINFER_V1

                if not is_default_backend_supported.can_import:
                    logger.warning_once(
                        "FlashInfer failed to import for V1 engine on "
                        "Blackwell (SM 10.0) GPUs; it is recommended to "
                        "install FlashInfer for better performance.")

            # FlashAttention is the default for SM 8.0+ GPUs
            if cls.has_device_capability(80):
                if (has_sink or
                        use_fp8_kv_cache) and not cls.is_device_capability(90):
                    logger.info_once("Using Triton backend on V1 engine.")
                    return TRITON_ATTN
                elif is_default_backend_supported := is_attn_backend_supported(
                        FLASH_ATTN_V1, head_size, dtype,
                        allow_import_error=False):
                    logger.info_once("Using Flash Attention backend on "
                                     "V1 engine.")
                    return FLASH_ATTN_V1
            else:
                # FlexAttention is the default for older GPUs
                logger.info_once("Using FlexAttention backend on V1 engine.")
                return FLEX_ATTENTION_V1

            # Only reachable on SM 8.0+ when FlashAttention rejected the
            # head size / dtype; report why before falling back.
            assert not is_default_backend_supported

            use_flex_attention_reason = {}
            if not is_default_backend_supported.head_size:
                use_flex_attention_reason["head_size"] = head_size
            if not is_default_backend_supported.dtype:
                use_flex_attention_reason["dtype"] = dtype

            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                ", ".join(f"{k}={v}"
                          for k, v in use_flex_attention_reason.items()),
            )
            return FLEX_ATTENTION_V1

        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend.")

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 requires compute capability 8.9 or newer."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an NCCL ``ProcessGroup`` without the global torch.dist state.

        ``backend`` is accepted for interface compatibility but not used; the
        group is always backed by NCCL with the given ``timeout``.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Return whether ``kv_cache_dtype`` works with the attention backend.

        The backend is read from ``envs.VLLM_ATTENTION_BACKEND``. When it is
        unset, this uses the platform default: CUTLASS_MLA (SM 10.0) or
        FLASHMLA for MLA models, and FLASH_ATTN otherwise.
        """
        fp8_attention = kv_cache_dtype.startswith("fp8")
        attention_backend = envs.VLLM_ATTENTION_BACKEND

        supported = False
        if model_config is not None and model_config.use_mla:
            # Default to CutlassMLA for blackwell,
            # FlashMLA otherwise
            if attention_backend is None:
                if cls.is_device_capability(100):
                    attention_backend = "CUTLASS_MLA"
                else:
                    attention_backend = "FLASHMLA"

            # FlashMLA, CUTLASS_MLA and FlashInfer MLA support fp8; every
            # other MLA backend is limited to non-fp8 KV cache dtypes.
            if attention_backend in [
                    "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"
            ]:
                supported = True
            else:
                supported = (not fp8_attention)
        else:
            # Default to FlashAttention
            if attention_backend is None:
                attention_backend = "FLASH_ATTN"

            # All Blackwell backends support fp8
            if cls.is_device_capability(100):
                supported = True
            elif attention_backend == "FLASH_ATTN":
                if fp8_attention:
                    from vllm.attention.utils.fa_utils import (
                        flash_attn_supports_fp8)
                    supported = flash_attn_supports_fp8()
                else:
                    supported = True
            elif attention_backend == "FLASHINFER":
                supported = True
            elif attention_backend == "TRITON_ATTN":
                supported = cls.supports_fp8()
        return supported

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise ``ValueError`` if ``torch_dtype`` is bf16 on a pre-SM 8.0 GPU.
        """
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU."""
        # Blocks are indexed along dim 1 of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU)."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    Methods that take a ``device_id`` first translate it with
    ``device_id_to_physical_device_id`` before asking NVML.
    ``is_fully_connected`` expects physical ids directly. Each public query
    opens and closes its own NVML session via ``with_nvml_context``.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability, or ``None`` on ``RuntimeError``.

        Results are memoized per ``(cls, device_id)`` by ``functools.cache``.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Base-class capability check run inside one NVML session.

        A ``RuntimeError`` from the base implementation is reported as
        ``False``.
        """
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        # Check every unordered pair (i < j) once.
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle,
                            peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if"
                            " your machine has no NVLink equipped.")
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Caller must already hold an NVML session; `device_id` is a
        # physical NVML index.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when mixed GPU models are present without PCI bus ordering."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_count)
            ]
            if (len(set(device_names)) > 1
                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform used when NVML is unavailable (e.g. on Jetson).

    Device properties are read through ``torch.cuda`` instead of NVML.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always ``False``: NVLink topology cannot be queried without NVML.
        """
        # There is no active exception here, so `logger.exception` would only
        # append a meaningless "NoneType: None" traceback; use `warning`.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    pynvml.nvmlInit()
    nvml_available = True
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # Only shut NVML down when the probe actually initialized it; an error
    # raised here propagates, exactly as from the previous `finally` clause.
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()