cuda.py 23.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume the CUDA platform, e.g. it may
import pynvml. However, it should not initialize the CUDA context.
"""

7
import os
8
from collections.abc import Callable
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, TypeVar
11

12
import torch
13
from typing_extensions import ParamSpec
14

15
16
# import custom ops, trigger op registration
import vllm._C  # noqa
17
18
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
19
from vllm.logger import init_logger
20
from vllm.utils.import_utils import import_pynvml
21
from vllm.utils.torch_utils import cuda_device_count_stateless
22

23
from .interface import DeviceCapability, Platform, PlatformEnum
24

25
if TYPE_CHECKING:
26
    from vllm.config import VllmConfig
27
    from vllm.config.cache import CacheDType
28
else:
29
30
    VllmConfig = None
    CacheDType = None
31

32
33
logger = init_logger(__name__)

# Signature-preserving type variables for the `with_nvml_context` decorator.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# NVML bindings; every NVML query in this module goes through this handle.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

43

44
45
46
47
48
49
50
51
52
53
54
55
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
    """Return candidate attention backends, most preferred first.

    The list index is used as the backend's priority by
    ``CudaPlatformBase.get_valid_backends`` (lower index = preferred).
    Results are memoized per ``(use_mla, device_capability)``.

    Args:
        use_mla: Whether the model uses Multi-head Latent Attention; selects
            the MLA-specific backend list.
        device_capability: Compute capability of the target GPU. Major
            version 10 gets a dedicated ordering.
    """
    if use_mla:
        if device_capability.major == 10:
            return [
                AttentionBackendEnum.CUTLASS_MLA,
                AttentionBackendEnum.FLASHINFER_MLA,
                AttentionBackendEnum.FLASH_ATTN_MLA,
                AttentionBackendEnum.FLASHMLA,
                AttentionBackendEnum.TRITON_MLA,
                AttentionBackendEnum.FLASHMLA_SPARSE,
            ]
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    if device_capability.major == 10:
        return [
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]
    return [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.FLASHINFER,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.FLEX_ATTENTION,
    ]


85
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap ``fn`` so NVML is initialized for the duration of each call.

    ``pynvml.nvmlInit()`` runs before ``fn`` and ``pynvml.nvmlShutdown()``
    runs afterwards, even if ``fn`` raises. If ``nvmlInit`` itself raises,
    ``fn`` is not called and ``nvmlShutdown`` is not attempted.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


97
98
class CudaPlatformBase(Platform):
    """CUDA platform logic shared by the NVML and non-NVML variants.

    Subclasses must implement the device-query primitives
    (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory``, ``is_fully_connected``), which raise
    ``NotImplementedError`` here.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the compute dtypes supported by the current GPU."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for CUDA-specific constraints.

        - Picks the GPU worker class when ``worker_cls`` is ``"auto"``.
        - Defaults the KV-cache block size to 16 when unset.
        - For MLA models, forces the block size required by the selected
          MLA backend (and pins CUTLASS_MLA when no backend was requested
          and the device capability is 10.0).
        - Downgrades or disables CUDA graphs for configurations that are
          incompatible with them (DCP/PCP, DeepEP high-throughput with DP).
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        # Note: block_size is initialized in
        # HybridAttentionMambaModelConfig.verify_and_update_config
        # for models with both attention and mamba,
        # and doesn't need to be reinitialized here
        if (
            model_config is not None
            and model_config.use_mla
            # Guard against a missing cache_config before touching it; the
            # default-block-size assignment above guards the same way.
            and cache_config is not None
            and cache_config.block_size is not None
        ):
            use_sparse = hasattr(model_config.hf_config, "index_topk")
            # If `--attention-config.backend` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False
            use_flashinfer_mla = False

            if vllm_config.attention_config.backend is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    # Set the backend in AttentionConfig so it's used during
                    # backend selection
                    vllm_config.attention_config.backend = (
                        AttentionBackendEnum.CUTLASS_MLA
                    )
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                backend = vllm_config.attention_config.backend
                use_flashmla = backend == AttentionBackendEnum.FLASHMLA
                use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
                use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA

            from vllm.attention.ops.flashmla import is_flashmla_dense_supported

            if (
                use_flashmla
                and is_flashmla_dense_supported()[0]
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size % 128 != 0:
                cache_config.block_size = 128
                logger.info(
                    "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
                )

            if (
                use_flashinfer_mla
                and cache_config.block_size != 32
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA backend."
                )

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse backend."
                )

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
            if parallel_config.decode_context_parallel_size > 1:
                logger.warning_once(
                    "Decode context parallel (DCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            # prefill context parallel do not support full cudagraphs
            elif parallel_config.prefill_context_parallel_size > 1:
                logger.warning_once(
                    "Prefill context parallel (PCP) is enabled, which is "
                    "incompatible with full CUDA graphs. "
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
        if (
            parallel_config.all2all_backend == "deepep_high_throughput"
            and parallel_config.data_parallel_size > 1
            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # TODO: Piecewise Cuda graph might be enabled
            # if torch compile cache key issue fixed
            # See https://github.com/vllm-project/vllm/pull/25093
            logger.info(
                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
                "kernels are optimized for prefill and are incompatible with "
                "CUDA Graphs. "
                "In order to use CUDA Graphs for decode-optimized workloads, "
                "use --all2all-backend with another option, such as "
                "deepep_low_latency, pplx, or allgather_reducescatter."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return peak allocated bytes on ``device`` after resetting the peak.

        The cache is emptied and the peak statistic reset first, so the value
        returned by ``max_memory_allocated`` reflects memory allocated at
        this moment.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
        """Pick the attention backend for vision-encoder (ViT) attention.

        FLASH_ATTN is used on compute capability >= 8.0 when it is importable
        and supports ``head_size`` and ``dtype``; otherwise TORCH_SDPA.
        """
        # Try FlashAttention first
        if (cc := cls.get_device_capability()) and cc.major >= 8:
            try:
                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
                if backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype):
                    return AttentionBackendEnum.FLASH_ATTN
            except ImportError:
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_valid_backends(
        cls,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        device_capability,
        attn_type,
    ) -> tuple[
        list[tuple["AttentionBackendEnum", int]],
        dict["AttentionBackendEnum", list[str]],
    ]:
        """Validate every candidate backend against this configuration.

        Returns:
            A pair ``(valid, invalid)``. ``valid`` lists
            ``(backend, priority)`` where priority is the backend's index in
            ``_get_backend_priorities`` (lower is preferred). ``invalid`` maps
            each rejected backend to its reasons; a backend whose class cannot
            be imported is recorded with reason ``"ImportError"``.
        """
        valid_backends_priorities = []
        invalid_reasons = {}

        backend_priorities = _get_backend_priorities(use_mla, device_capability)
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    head_size,
                    dtype,
                    kv_cache_dtype,
                    block_size,
                    use_mla,
                    has_sink,
                    use_sparse,
                    device_capability,
                    attn_type,
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = invalid_reasons_i
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        attn_type: str | None = None,
    ) -> str:
        """Return the import path of the attention backend to use.

        An explicitly ``selected_backend`` is validated and returned as-is;
        if it is invalid a ``ValueError`` is raised (there is no fallback).
        Otherwise the valid backend with the lowest priority value is chosen.

        Note: ``None`` is passed as the block size to backend validation;
        ``block_size`` only appears in the log/error messages.

        Raises:
            ValueError: if the selected backend is invalid, or if no backend
                is valid for this configuration.
        """
        if attn_type is None:
            attn_type = AttentionType.DECODER

        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # An explicitly selected backend is either used or rejected outright.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    head_size,
                    dtype,
                    kv_cache_dtype,
                    None,
                    use_mla,
                    has_sink,
                    use_sparse,
                    device_capability,
                    attn_type,
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was selected, so search for the best valid one.
        valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
            head_size,
            dtype,
            kv_cache_dtype,
            None,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
            attn_type,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, reasons in invalid_reasons.items()
            )
            + "}"
        )
        config_str = (
            f"head_size: {head_size}, dtype: {dtype}, "
            f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
            f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
        )
        # Only report rejected backends when there actually are some.
        if invalid_reasons:
            logger.debug_once(
                f"Some attention backends are not valid for {cls.device_name} with "
                f"{config_str}. Reasons: {reasons_str}."
            )
        if not valid_backends_priorities:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # Lower priority value means more preferred.
        chosen_backend, _ = min(
            valid_backends_priorities, key=lambda backend_prio: backend_prio[1]
        )
        logger.info(
            "Using %s attention backend out of potential backends: %s",
            chosen_backend.name,
            [b[0].name for b in valid_backends_priorities],
        )

        return chosen_backend.get_path()

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ``ValueError`` if ``dtype`` is bfloat16 on a pre-8.0 GPU."""
        if dtype == torch.bfloat16 and not cls.has_device_capability(80):
            capability = cls.get_device_capability()
            gpu_name = cls.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs "
                "with compute capability of at least 8.0. "
                f"Your {gpu_name} GPU {compute_str}. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half."
            )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU)."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True

522

523
524
525
526
527
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through NVML.

    Logical device ids are translated with ``device_id_to_physical_device_id``
    before being passed to NVML.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability, or None on ``RuntimeError``."""
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Check each unordered pair once (same order as an i < j double loop).
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Callers must already hold an NVML context (no decorator here).
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn about mixed GPU models unless CUDA_DEVICE_ORDER=PCI_BUS_ID."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_count)
            ]
            if (
                len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
            ):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform fallback that queries devices through ``torch.cuda``.

    Used when NVML cannot be initialized. Unlike the NVML variant, these
    queries go through the CUDA runtime via torch.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always return False: NVLink cannot be detected without NVML."""
        # `logger.exception` is only meaningful inside an exception handler;
        # here there is no active exception, so log a plain warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Pick the NVML-backed platform when NVML can be initialized, otherwise
# fall back to the torch.cuda-backed one.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # Init succeeded, so the probe's NVML session must be closed again.
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()