cuda.py 23.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
8
from __future__ import annotations

9
import os
10
from collections.abc import Callable
11
from datetime import timedelta
12
from functools import cache, wraps
13
from typing import TYPE_CHECKING, TypeVar
14

15
import torch
16
17
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
18
from typing_extensions import ParamSpec
19

20
21
# import custom ops, trigger op registration
import vllm._C  # noqa
22
from vllm.logger import init_logger
23
from vllm.utils.import_utils import import_pynvml
24
from vllm.utils.torch_utils import cuda_device_count_stateless
25
from vllm.v1.attention.backends.registry import AttentionBackendEnum
26

27
from .interface import DeviceCapability, Platform, PlatformEnum
28

29
if TYPE_CHECKING:
30
    from vllm.config import VllmConfig
31
    from vllm.config.cache import CacheDType
32
    from vllm.v1.attention.selector import AttentionSelectorConfig
33
else:
34
35
    VllmConfig = None
    CacheDType = None
36

37
38
# Type parameters that let `with_nvml_context` preserve the exact signature of
# the callable it wraps.
_P = ParamSpec("_P")
_R = TypeVar("_R")

logger = init_logger(__name__)

# NVML Python bindings, resolved through vLLM's import helper.
pynvml = import_pynvml()

# PyTorch 2.5 uses the cuDNN SDPA kernel by default, which causes a crash on
# some models; see https://github.com/huggingface/diffusers/issues/9704 for
# details.
torch.backends.cuda.enable_cudnn_sdp(False)

48

49
50
51
52
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
    num_heads: int | None = None,
    kv_cache_dtype: CacheDType | None = None,
) -> list[AttentionBackendEnum]:
    """Return CUDA attention backends ordered from most to least preferred.

    The backend enum members are only referenced inside this function body,
    which keeps the lookup lazy and avoids a circular dependency.

    Args:
        use_mla: Whether an MLA attention backend is required.
        device_capability: Compute capability of the target GPU. Devices whose
            major version is 10 get a dedicated ordering.
        num_heads: Number of attention heads, if known. Only consulted for the
            sparse-MLA ordering on major-version-10 devices.
        kv_cache_dtype: KV-cache dtype name, if known. Only consulted for the
            sparse-MLA ordering on major-version-10 devices.

    Returns:
        Backend enum members in priority order (index 0 is preferred). The
        result is memoized by ``functools.cache``, so the same list object is
        returned for equal arguments and must not be mutated by callers.
    """
    is_major_10 = device_capability.major == 10

    if not use_mla:
        # Major-version-10 devices try FlashInfer first; all others try
        # FlashAttention first. The tail of the list is shared.
        if is_major_10:
            leading = [AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASH_ATTN]
        else:
            leading = [AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASHINFER]
        return [
            *leading,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]

    if not is_major_10:
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    # Sparse MLA ordering; benchmark results are in
    # https://github.com/vllm-project/vllm/issues/35807
    # FlashInfer is preferred for an fp8 KV cache, and (with a BF16 KV cache)
    # at low head counts because FlashMLA uses padding.
    prefer_flashinfer_sparse = (
        kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8")
    ) or (num_heads is not None and num_heads <= 16)

    sparse_pair = [
        AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
        AttentionBackendEnum.FLASHMLA_SPARSE,
    ]
    if not prefer_flashinfer_sparse:
        sparse_pair.reverse()

    return [
        AttentionBackendEnum.FLASHINFER_MLA,
        AttentionBackendEnum.CUTLASS_MLA,
        AttentionBackendEnum.FLASH_ATTN_MLA,
        AttentionBackendEnum.FLASHMLA,
        AttentionBackendEnum.TRITON_MLA,
        *sparse_pair,
    ]


115
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap *fn* so that each call runs inside an NVML init/shutdown pair.

    ``pynvml.nvmlInit()`` is called before *fn*, and ``pynvml.nvmlShutdown()``
    is called afterwards even if *fn* raises. If ``nvmlInit`` itself raises,
    *fn* is not invoked and no shutdown is attempted.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
        return outcome

    return _nvml_scoped


127
128
class CudaPlatformBase(Platform):
    """Behaviour shared by all CUDA platform variants.

    Device-query methods (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory``, ``is_fully_connected``) raise
    ``NotImplementedError`` here and are implemented by subclasses.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    ]

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Dtypes usable on the current device, chosen by compute capability."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Hook for platform-specific startup warnings; no-op by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in CUDA-specific defaults on ``vllm_config`` in place.

        - Resolves ``worker_cls == "auto"`` to the v1 GPU worker.
        - Forces ``disable_chunked_mm_input`` for multimodal prefix-LM models.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated`` right after a peak reset.

        The caching allocator is emptied and the peak statistics are reset
        first, so the reported peak starts from the current allocation.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: AttentionSelectorConfig,
        num_heads: int | None = None,
    ) -> tuple[
        list[tuple[AttentionBackendEnum, int]],
        dict[AttentionBackendEnum, tuple[int, list[str]]],
    ]:
        """Validate every candidate backend against the given configuration.

        Returns:
            A pair ``(valid, invalid)``:

            - ``valid`` is a list of ``(backend, priority)`` entries, appended
              in priority order.
            - ``invalid`` maps each rejected backend to ``(priority, reasons)``.

            A backend whose class cannot be imported is rejected with the
            reason ``"ImportError"``.
        """
        valid_backends_priorities: list[tuple[AttentionBackendEnum, int]] = []
        invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}

        backend_priorities = _get_backend_priorities(
            attn_selector_config.use_mla,
            device_capability,
            num_heads,
            attn_selector_config.kv_cache_dtype,
        )

        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = (priority, invalid_reasons_i)
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: AttentionBackendEnum | None,
        attn_selector_config: AttentionSelectorConfig,
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use.

        If ``selected_backend`` is given, only that backend is validated and
        used; an invalid selection raises instead of falling back. Otherwise
        the valid backend with the highest priority (lowest index) is chosen.

        Raises:
            ValueError: If the selected backend is invalid, or if no backend is
                valid for this configuration.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # An explicitly selected backend is validated on its own; there is no
        # fallback to auto-selection if it turns out to be invalid.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was selected, so search the priority list for one.
        valid_backends_priorities, all_invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, (_, reasons) in all_invalid_reasons.items()
            )
            + "}"
        )
        config_str = repr(attn_selector_config)
        # Only report rejections when at least one backend was rejected;
        # otherwise the message would claim invalid backends with "Reasons: {}".
        if all_invalid_reasons:
            logger.debug_once(
                f"Some attention backends are not valid for {cls.device_name} with "
                f"{config_str}. Reasons: {reasons_str}."
            )
        if not valid_backends_priorities:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # Select the valid backend with the lowest priority index (index 0 is
        # the most preferred). min() returns the first minimal entry, the same
        # one a stable sort would place first.
        selected_backend, selected_priority = min(
            valid_backends_priorities, key=lambda entry: entry[1]
        )

        # If the user specified --block-size (but not --attention-backend),
        # check whether that constraint precluded any higher-priority backends.
        if attn_selector_config.block_size is not None:
            excluded = [
                backend
                for backend, (priority, reasons) in all_invalid_reasons.items()
                if priority < selected_priority
                and reasons == ["block_size not supported"]
            ]
            if excluded:
                names = ", ".join(b.name for b in excluded)
                logger.warning(
                    "--block-size %d precluded higher-priority backend(s) "
                    "%s. Using %s instead, which may result in reduced "
                    "performance. Consider removing --block-size to "
                    "auto-select the optimal block size.",
                    attn_selector_config.block_size,
                    names,
                    selected_backend.name,
                )

        logger.info_once(
            "Using %s attention backend out of potential backends: %s.",
            selected_backend.name,
            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
            scope="local",
        )

        return selected_backend.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list[AttentionBackendEnum]:
        """ViT attention backends in preference order for the current device.

        On devices with compute capability >= 8.0, Triton is tried before
        Torch SDPA; on older devices the two are swapped.
        """
        if cls.has_device_capability(80):
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.FLASHINFER,
            ]
        else:
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.FLASHINFER,
            ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: AttentionBackendEnum | None = None,
    ) -> AttentionBackendEnum:
        """Choose the attention backend for vision-transformer attention.

        An explicit ``backend`` must be in ``get_supported_vit_attn_backends()``
        (checked with an assert) and is returned as-is. Otherwise the supported
        backends are scanned in order:

        - ``TORCH_SDPA`` is returned as soon as it is reached.
        - Any other backend is returned if it supports ``head_size``, ``dtype``
          and (when known) the device's compute capability.
        - Backends that fail to import are skipped.

        ``TORCH_SDPA`` is the final fallback.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        cc = cls.get_device_capability()
        for vit_attn_backend in cls.get_supported_vit_attn_backends():
            if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
                return vit_attn_backend
            try:
                backend_class = vit_attn_backend.get_class()
                is_backend_supported = backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype)
                if cc is not None:
                    is_backend_supported = (
                        is_backend_supported
                        and backend_class.supports_compute_capability(cc)
                    )
                if is_backend_supported:
                    logger.info_once(
                        f"Using backend {vit_attn_backend} for vit attention"
                    )
                    return vit_attn_backend
            except ImportError:
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Import path of the GPU Punica wrapper used for LoRA."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Import path of the CUDA device communicator."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 is reported as supported on compute capability 8.9 and newer."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Import path of the CUDA graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ``ProcessGroup`` with an NCCL backend registered for CUDA.

        The ``backend`` argument is not consulted; NCCL is always used. The
        group is built directly from ``prefix_store`` and registered with
        NCCL as its default backend.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Number of visible CUDA devices (via ``cuda_device_count_stateless``)."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ``ValueError`` if ``dtype`` is bfloat16 on a pre-8.0 GPU."""
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU.

        Blocks are selected along dimension 1 of both caches; the gathered
        source blocks are moved to ``dst_cache.device`` before the write.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU).

        Blocks are selected along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True

    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Streaming-multiprocessor count reported by torch for ``device_id``."""
        return torch.cuda.get_device_properties(device_id).multi_processor_count

    @classmethod
    def use_custom_op_collectives(cls) -> bool:
        return True

520

521
522
523
524
525
# NVML utils
# NVML ignores `CUDA_VISIBLE_DEVICES`: every NVML call below operates on
# physical device ids. The main benefit of going through NVML is that it does
# not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through NVML."""

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Compute capability of logical device ``device_id``.

        Returns ``None`` when a ``RuntimeError`` is raised while mapping the
        id or querying NVML. Results are memoized per ``(cls, device_id)``.
        """
        try:
            physical_id = cls.device_id_to_physical_device_id(device_id)
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(physical_id)
            cc_major, cc_minor = pynvml.nvmlDeviceGetCudaComputeCapability(
                nvml_handle
            )
            return DeviceCapability(major=cc_major, minor=cc_minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Base-class capability check, reporting ``False`` on ``RuntimeError``."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """NVML name of logical device ``device_id``."""
        return cls._get_physical_device_name(
            cls.device_id_to_physical_device_id(device_id)
        )

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """NVML UUID of logical device ``device_id``."""
        physical_id = cls.device_id_to_physical_device_id(device_id)
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(physical_id)
        return pynvml.nvmlDeviceGetUUID(nvml_handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total memory reported by NVML for logical device ``device_id``."""
        physical_id = cls.device_id_to_physical_device_id(device_id)
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(physical_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of devices must report NVML_P2P_STATUS_OK for
        the NVLink P2P capability; any NVML error is logged and treated as
        "not connected".
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Visit each unordered pair once: device k is paired with k+1, k+2, ...
        for next_idx, handle in enumerate(handles, start=1):
            for peer_handle in handles[next_idx:]:
                try:
                    link_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if link_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """NVML name of *physical* device ``device_id``.

        Does not open an NVML session itself; the callers in this class are
        wrapped with ``with_nvml_context``.
        """
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(nvml_handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when heterogeneous GPUs are present without PCI_BUS_ID ordering."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return
        device_names = [cls._get_physical_device_name(i) for i in range(device_count)]
        has_mixed_devices = len(set(device_names)) > 1
        if has_mixed_devices and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through ``torch.cuda``.

    Selected by the module-level autodetection when ``pynvml.nvmlInit()``
    fails.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Compute capability of ``device_id`` as reported by torch (memoized)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Device name of ``device_id`` as reported by torch."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total memory of ``device_id`` from torch device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always ``False``: NVLink cannot be detected without NVML.

        The notice is logged with ``logger.warning``. ``logger.exception`` must
        only be used inside an exception handler; called here it would append
        a bogus "NoneType: None" traceback to an expected condition.
        """
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Autodetect the platform: use the NVML-backed implementation when NVML can be
# initialized, otherwise fall back to the torch.cuda-based one.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # The probe session is only used for detection, so close it right away.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()