# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Code inside this file can safely assume the CUDA platform, e.g. importing
pynvml. However, it should not initialize a CUDA context.
"""

from __future__ import annotations

import os
from collections.abc import Callable
from datetime import timedelta
from functools import cache, lru_cache, wraps
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec

# import custom ops, trigger op registration
import vllm._C  # noqa
import vllm._C_stable_libtorch  # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum

from .interface import DeviceCapability, Platform, PlatformEnum
30

31
# These names are only needed for annotations; importing them for real is
# restricted to type-checking runs.
if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheDType
    from vllm.config.kernel import IrOpPriorityConfig
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # Runtime placeholders so the names exist at module level.
    VllmConfig = None
    CacheDType = None

logger = init_logger(__name__)

# Type variables used to preserve the signature of decorated callables.
_P = ParamSpec("_P")
_R = TypeVar("_R")

pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
    """Return the number of CUDA devices, cached per CUDA_VISIBLE_DEVICES value.

    ``cuda_visible_devices`` is never read inside the function. It exists only
    so the LRU cache keeps a separate entry for each value of
    CUDA_VISIBLE_DEVICES that was in effect when the call was made.

    Use this instead of ``torch.accelerator.device_count()`` unless
    CUDA_VISIBLE_DEVICES has already been set to the desired value.

    TODO: drop this helper in favour of the upstream device-count API once
    https://github.com/pytorch/pytorch/pull/122815 is released.
    """
    # Adapted from
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda

    if not torch.cuda._is_compiled():
        return 0

    nvml_count = torch.cuda._device_count_nvml()
    if nvml_count >= 0:
        return nvml_count
    # A negative NVML count means the NVML path gave no answer; ask the
    # CUDA runtime binding instead.
    return torch._C._cuda_getDeviceCount()


78
79
80
81
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
    num_heads: int | None = None,
    kv_cache_dtype: CacheDType | None = None,
) -> list[AttentionBackendEnum]:
    """Return attention backends ordered from most to least preferred.

    Args:
        use_mla: Whether to return the MLA backend ordering.
        device_capability: Capability of the target GPU; only ``major == 10``
            is special-cased.
        num_heads: Head count, used to order the sparse MLA backends when
            ``major == 10``.
        kv_cache_dtype: KV-cache dtype, used to order the sparse MLA backends
            when ``major == 10``.

    Returns:
        The prioritised backend list. The result is memoised by ``@cache``,
        so the same list object is handed to every caller with the same
        arguments; callers must not mutate it.
    """
    is_cc_major_10 = device_capability.major == 10

    if not use_mla:
        if is_cc_major_10:
            return [
                AttentionBackendEnum.FLASHINFER,
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.FLEX_ATTENTION,
            ]
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]

    if not is_cc_major_10:
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    # Sparse MLA backend priorities
    # See https://github.com/vllm-project/vllm/issues/35807 for
    # benchmark results
    # FlashInfer goes first for a quantized (fp8) KV cache, and for a BF16 KV
    # cache at low head counts (FlashMLA uses padding).
    quantized_kv_cache = kv_cache_dtype is not None and is_quantized_kv_cache(
        kv_cache_dtype
    )
    low_head_count = num_heads is not None and num_heads <= 16
    if quantized_kv_cache or low_head_count:
        sparse_backends = [
            AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]
    else:
        sparse_backends = [
            AttentionBackendEnum.FLASHMLA_SPARSE,
            AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
        ]

    return [
        AttentionBackendEnum.FLASHINFER_MLA,
        AttentionBackendEnum.CUTLASS_MLA,
        AttentionBackendEnum.FLASH_ATTN_MLA,
        AttentionBackendEnum.FLASHMLA,
        AttentionBackendEnum.TRITON_MLA,
        *sparse_backends,
    ]


144
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap ``fn`` so that it runs inside an NVML init/shutdown pair.

    The wrapper calls ``pynvml.nvmlInit()`` before ``fn`` and always calls
    ``pynvml.nvmlShutdown()`` afterwards, including when ``fn`` raises.

    Args:
        fn: The callable to wrap.

    Returns:
        A callable with the same signature and return value as ``fn``.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


156
157
class CudaPlatformBase(Platform):
    """Behaviour shared by the CUDA platform implementations in this module.

    The device-query methods (capability, name, total memory, connectivity)
    raise ``NotImplementedError`` here and are supplied by subclasses.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    ]

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Floating-point dtypes usable on the device, by compute capability."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def manual_seed_all(cls, seed: int) -> None:
        """Forward ``seed`` to ``torch.cuda.manual_seed_all``."""
        torch.cuda.manual_seed_all(seed)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Report whether the given devices are fully connected (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls) -> None:
        """Emit platform warnings; the base implementation does nothing."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in CUDA defaults on ``vllm_config`` (mutates it in place).

        Resolves ``worker_cls == "auto"`` to the GPU worker and forces
        ``disable_chunked_mm_input`` when the model/scheduler flags below
        require it.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated`` after a cache/peak reset.

        The caching allocator is emptied and the peak statistics for
        ``device`` are reset before the value is read.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: AttentionSelectorConfig,
        num_heads: int | None = None,
    ) -> tuple[
        list[tuple[AttentionBackendEnum, int]],
        dict[AttentionBackendEnum, tuple[int, list[str]]],
    ]:
        """Split the prioritised backends into valid and invalid ones.

        Returns:
            A pair ``(valid, invalid)``. ``valid`` is a list of
            ``(backend, priority)`` tuples in priority order; ``invalid``
            maps each rejected backend to ``(priority, reasons)``. A backend
            whose class cannot be imported is rejected with the single
            reason ``"ImportError"``. Lower priority numbers are preferred.
        """
        valid_backends_priorities: list[tuple[AttentionBackendEnum, int]] = []
        invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}

        backend_priorities = _get_backend_priorities(
            attn_selector_config.use_mla,
            device_capability,
            num_heads,
            attn_selector_config.kv_cache_dtype,
        )
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = (priority, invalid_reasons_i)
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: AttentionBackendEnum | None,
        attn_selector_config: AttentionSelectorConfig,
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use.

        Args:
            selected_backend: Backend explicitly requested by the user, or
                ``None`` to auto-select.
            attn_selector_config: Configuration each backend is validated
                against.
            num_heads: Forwarded to the priority lookup.

        Returns:
            ``get_path()`` of the chosen backend.

        Raises:
            ValueError: If ``selected_backend`` is given but invalid for this
                configuration, or if auto-selection finds no valid backend.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # An explicitly requested backend is validated on its own; there is
        # no fallback to auto-selection when it is rejected.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was requested, so look for the best valid one.
        valid_backends_priorities, all_invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, (_, reasons) in all_invalid_reasons.items()
            )
            + "}"
        )
        config_str = repr(attn_selector_config)
        logger.debug_once(
            f"Some attention backends are not valid for {cls.device_name} with "
            f"{config_str}. Reasons: {reasons_str}."
        )
        if not valid_backends_priorities:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # We have found some valid backends. Select the one with the
        # highest priority, i.e. the smallest priority number; ``min`` keeps
        # the first entry on ties.
        selected_backend, selected_priority = min(
            valid_backends_priorities, key=lambda entry: entry[1]
        )

        # If the user specified --block-size (but not --attention-backend),
        # check whether that constraint precluded any higher-priority backends.
        if attn_selector_config.block_size is not None:
            excluded = [
                backend
                for backend, (priority, reasons) in all_invalid_reasons.items()
                if priority < selected_priority
                and reasons == ["block_size not supported"]
            ]
            if excluded:
                names = ", ".join(b.name for b in excluded)
                logger.warning(
                    "--block-size %d precluded higher-priority backend(s) "
                    "%s. Using %s instead, which may result in reduced "
                    "performance. Consider removing --block-size to "
                    "auto-select the optimal block size.",
                    attn_selector_config.block_size,
                    names,
                    selected_backend.name,
                )

        logger.info_once(
            "Using %s attention backend out of potential backends: %s.",
            selected_backend.name,
            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
            scope="local",
        )

        return selected_backend.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list[AttentionBackendEnum]:
        """Return the ViT attention backends in order of preference.

        The ordering of TRITON_ATTN and TORCH_SDPA depends on whether the
        device has compute capability 8.0 or newer.
        """
        if cls.has_device_capability(80):
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.FLASHINFER,
            ]
        else:
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.FLASHINFER,
            ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: AttentionBackendEnum | None = None,
    ) -> AttentionBackendEnum:
        """Pick the attention backend for ViT layers.

        An explicit ``backend`` is returned as-is after asserting that it is
        in the supported list. Otherwise the supported list is walked in
        order and the first backend that accepts ``head_size``, ``dtype`` and
        the device capability is returned. TORCH_SDPA is accepted
        unconditionally when reached and is also the final fallback.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        cc = cls.get_device_capability()
        for vit_attn_backend in cls.get_supported_vit_attn_backends():
            if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
                return vit_attn_backend
            try:
                backend_class = vit_attn_backend.get_class()
                is_backend_supported = backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype)
                if cc is not None:
                    is_backend_supported = (
                        is_backend_supported
                        and backend_class.supports_compute_capability(cc)
                    )
                if is_backend_supported:
                    logger.info_once(
                        f"Using backend {vit_attn_backend} for vit attention",
                        scope="local",
                    )
                    return vit_attn_backend
            except ImportError:
                # The backend's class is not importable; try the next one.
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA punica wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator class."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether the device has compute capability 8.9 or newer."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Return ``True`` unconditionally."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Return ``True`` unconditionally."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the CUDA graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ``ProcessGroup`` with an NCCL backend registered for CUDA.

        ``backend`` is accepted for interface compatibility but is not read;
        NCCL is always used. ``timeout`` is applied to the NCCL backend
        options.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the CUDA device count for the current CUDA_VISIBLE_DEVICES."""
        return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype) -> None:
        """Reject bfloat16 on devices below compute capability 8.0.

        Raises:
            ValueError: If ``dtype`` is ``torch.bfloat16`` and the device does
                not have compute capability 8.0 or newer.
        """
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        if capability is None:
            compute_str = "does not have a compute capability"
        else:
            version_str = capability.as_version_str()
            compute_str = f"has compute capability {version_str}"

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU."""
        # Blocks are selected along the second dimension of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU)."""
        # Blocks are selected along the second dimension of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Return ``True`` unconditionally."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Return ``True`` unconditionally."""
        return True

    @classmethod
    def support_deep_gemm(cls) -> bool:
        """Currently, only Hopper and Blackwell GPUs are supported."""
        return cls.is_device_capability(90) or cls.is_device_capability_family(100)

    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Return the multiprocessor count reported by torch for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).multi_processor_count

    @classmethod
    def use_custom_op_collectives(cls) -> bool:
        """Return ``True`` unconditionally."""
        return True

    @classmethod
    def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConfig:
        """Return the default IR op implementation priorities for CUDA."""
        from vllm.config.compilation import CompilationMode
        from vllm.config.kernel import IrOpPriorityConfig

        # Native used by default when compiling,
        # use vllm_c kernels where available when no codegen
        cc = vllm_config.compilation_config
        using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
        default = ["native"] if using_inductor else ["vllm_c", "native"]

        # Use oink if enabled for rms_norm
        # TODO(Laurawly/luka): remove this env var,
        #  users can just use IR op priority directly
        rms_norm = default
        if envs.VLLM_USE_OINK_OPS:
            rms_norm = ["oink"] + default

        return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)

579

580
581
582
583
584
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through pynvml.

    Public query methods are wrapped in ``with_nvml_context`` so each call
    runs between ``nvmlInit`` and ``nvmlShutdown``. Logical device ids are
    translated with ``device_id_to_physical_device_id`` before NVML is asked.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``.

        Returns ``None`` if a ``RuntimeError`` is raised during the lookup.
        The result is memoised per ``(cls, device_id)``.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Delegate to the parent check, returning ``False`` on RuntimeError."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for logical device ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID for logical device ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the NVML-reported total memory of ``device_id`` as an int."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Check every unordered pair once: each handle against those after it.
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of a physical device index.

        Not wrapped in ``with_nvml_context``; it does not initialise NVML
        itself.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def get_device_numa_node(cls, device_id: int = 0) -> int | None:
        """Get the NUMA node ID for a GPU device.

        The node reported by NVML is used when it has CPUs assigned.
        Otherwise, or if that query fails, the node of the first CPU in the
        GPU's NVML CPU affinity is used. Returns ``None`` when neither
        approach yields a node.
        """
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)

        try:
            numa_node = pynvml.nvmlDeviceGetNumaNodeId(handle)
            if cls._numa_node_has_cpus(numa_node):
                return numa_node
            # On non-CDMM Grace-Blackwell systems (e.g. GB200), each GPU's HBM
            # is a separate NUMA node with no CPUs.  Fall through to
            # CPU-affinity-based detection to find the nearest CPU node.
            logger.debug(
                "NUMA node %d for GPU %d has no CPUs (non-CDMM topology), "
                "falling back to CPU-affinity-based detection",
                numa_node,
                device_id,
            )
        except Exception as e:
            # Best effort: the direct query is optional, so record why it was
            # skipped and continue with the CPU-affinity fallback.
            logger.debug(
                "Direct NUMA node query failed for GPU %d (%s), "
                "falling back to CPU-affinity-based detection",
                device_id,
                e,
            )

        try:
            cpu_ids = cls._get_device_cpu_affinity(handle)
            if cpu_ids:
                numa_node = cls._get_numa_node_for_cpu(cpu_ids[0])
                if numa_node is not None:
                    logger.debug(
                        "Determined NUMA node %d for GPU %d via CPU affinity",
                        numa_node,
                        device_id,
                    )
                    return numa_node
        except Exception as e:
            logger.warning("Failed to get NUMA node for GPU %d: %s", device_id, e)

        return None

    @classmethod
    def _numa_node_has_cpus(cls, node_id: int) -> bool:
        """Check whether a NUMA node has any CPUs assigned to it.

        Reads the node's ``cpulist`` file under ``/sys``; an unreadable file
        counts as "no CPUs".
        """
        cpulist_file = Path(f"/sys/devices/system/node/node{node_id}/cpulist")
        try:
            return cpulist_file.read_text().strip() != ""
        except (OSError, ValueError):
            return False

    @classmethod
    def _get_device_cpu_affinity(cls, handle) -> list[int]:
        """Get the list of CPU IDs associated with a GPU via NVML.

        Returns an empty list when ``os.cpu_count()`` is unknown.
        """
        cpu_count = os.cpu_count()
        if cpu_count is None:
            return []

        # The affinity is requested as 64-bit words, one bit per CPU.
        cpu_set_size = (cpu_count + 63) // 64
        cpu_affinity_mask = pynvml.nvmlDeviceGetCpuAffinity(handle, cpu_set_size)

        cpu_ids = []
        for i, mask in enumerate(cpu_affinity_mask):
            for bit in range(64):
                cpu_id = i * 64 + bit
                # Bits past the last CPU in the final word are padding.
                if cpu_id >= cpu_count:
                    break
                if mask & (1 << bit):
                    cpu_ids.append(cpu_id)
        return cpu_ids

    @classmethod
    def _get_numa_node_for_cpu(cls, cpu_id: int) -> int | None:
        """Determine which NUMA node a CPU belongs to.

        Scans the ``node*`` directories under ``/sys/devices/system/node``
        and returns the id of the first one whose ``cpulist`` contains
        ``cpu_id``, or ``None`` if there is none. Entries that cannot be
        parsed or read are skipped.
        """
        node_path = Path("/sys/devices/system/node")
        if not node_path.exists():
            return None

        for node_dir in node_path.iterdir():
            if not node_dir.name.startswith("node"):
                continue
            try:
                node_id = int(node_dir.name[4:])
                cpulist_file = node_dir / "cpulist"
                if cpulist_file.exists():
                    cpulist = cpulist_file.read_text().strip()
                    if cls._cpu_in_cpulist(cpu_id, cpulist):
                        return node_id
            except (ValueError, OSError):
                continue
        return None

    @classmethod
    def _cpu_in_cpulist(cls, cpu_id: int, cpulist: str) -> bool:
        """Check if a CPU ID is in a cpulist string such as '0-3,8-11'.

        Raises:
            ValueError: If a range bound in ``cpulist`` is not an integer.
        """
        for part in cpulist.split(","):
            part = part.strip()
            if "-" in part:
                start, end = part.split("-", 1)
                if int(start) <= cpu_id <= int(end):
                    return True
            elif part.isdigit() and int(part) == cpu_id:
                return True
        return False

    @classmethod
    @with_nvml_context
    def get_all_device_numa_nodes(cls) -> list[int] | None:
        """Get NUMA nodes for all visible GPU devices.

        Returns ``None`` as soon as any device's node cannot be determined,
        or if the lookup raises.
        """
        try:
            numa_nodes = []
            for device_id in range(cls.device_count()):
                numa_node = cls.get_device_numa_node(device_id)
                if numa_node is None:
                    logger.warning(
                        "Could not detect NUMA node for GPU %d, "
                        "disabling automatic NUMA binding",
                        device_id,
                    )
                    return None
                numa_nodes.append(numa_node)
            return numa_nodes
        except Exception as e:
            logger.warning("Failed to get NUMA nodes for GPUs: %s", e)
            return None

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when mixed GPU models are present without PCI_BUS_ID ordering."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return

        device_names = [cls._get_physical_device_name(i) for i in range(device_count)]
        if (
            len(set(device_names)) > 1
            and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
        ):
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``.

    Selected at module import time when NVML cannot be initialised.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability torch reports for ``device_id``.

        The result is memoised per ``(cls, device_id)``.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name torch reports for ``device_id``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory torch reports for ``device_id``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always report ``False``: this platform performs no NVLink probe."""
        # This is not inside an exception handler, so logger.exception would
        # attach a meaningless "NoneType: None" traceback; log a plain warning.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False

    @classmethod
    def get_device_numa_node(cls, device_id: int = 0) -> int | None:
        """Return ``None``; NUMA detection is not implemented here."""
        return None

    @classmethod
    def get_all_device_numa_nodes(cls) -> list[int] | None:
        """Return ``None``; NUMA detection is not implemented here."""
        return None

837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853

# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # The probe succeeded; close the NVML session it opened.
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()