cuda.py 23 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from collections.abc import Callable
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, TypeVar
11

12
import torch
13
from typing_extensions import ParamSpec
14

15
16
# import custom ops, trigger op registration
import vllm._C  # noqa
17
import vllm.envs as envs
18
from vllm.logger import init_logger
19
from vllm.utils.import_utils import import_pynvml
20
from vllm.utils.torch_utils import cuda_device_count_stateless
21

22
from .interface import DeviceCapability, Platform, PlatformEnum
23

24
if TYPE_CHECKING:
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheDType
else:
    # Runtime placeholders: some annotations in this module are not quoted
    # (e.g. `list[AttentionBackendEnum]`), so the names must exist when those
    # definitions are evaluated, even though the real types are only
    # imported for static type checking.
    AttentionBackendEnum = VllmConfig = CacheDType = None
32

33
34
logger = init_logger(__name__)

# Type variables that let `with_nvml_context` keep the decorated function's
# parameter and return types visible to type checkers.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# Module-wide pynvml handle, obtained through vLLM's import helper.
pynvml = import_pynvml()

# PyTorch 2.5 turned on the cuDNN SDPA backend by default, and it crashes on
# some models, so it is disabled globally here.
# See https://github.com/huggingface/diffusers/issues/9704 for details.
torch.backends.cuda.enable_cudnn_sdp(False)

44

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
    """Return the CUDA attention backends in preference order (first = best).

    The order depends on whether the model uses MLA and on whether the GPU's
    major compute capability is 10. Results are memoized per argument pair.
    The registry enum is imported lazily to avoid a circular dependency.
    """
    from vllm.attention.backends.registry import AttentionBackendEnum as Backend

    major_is_10 = device_capability.major == 10

    if not use_mla:
        if major_is_10:
            return [
                Backend.FLASHINFER,
                Backend.FLASH_ATTN,
                Backend.TRITON_ATTN,
                Backend.FLEX_ATTENTION,
            ]
        return [
            Backend.FLASH_ATTN,
            Backend.FLASHINFER,
            Backend.TRITON_ATTN,
            Backend.FLEX_ATTENTION,
        ]

    if major_is_10:
        return [
            Backend.CUTLASS_MLA,
            Backend.FLASHINFER_MLA,
            Backend.FLASHMLA,
            Backend.FLASH_ATTN_MLA,
            Backend.TRITON_MLA,
            Backend.FLASHMLA_SPARSE,
        ]
    return [
        Backend.FLASHMLA,
        Backend.FLASH_ATTN_MLA,
        Backend.FLASHINFER_MLA,
        Backend.TRITON_MLA,
        Backend.FLASHMLA_SPARSE,
    ]


88
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so each call runs inside its own NVML session.

    ``pynvml.nvmlInit()`` is called before ``fn`` and ``pynvml.nvmlShutdown()``
    afterwards, including when ``fn`` raises.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
        return result

    return _nvml_scoped


100
101
class CudaPlatformBase(Platform):
    """Platform logic shared by every CUDA platform variant.

    Device queries (capability, name, total memory, NVLink connectivity)
    raise ``NotImplementedError`` here and are implemented by subclasses.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the floating-point dtypes usable on this GPU generation."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform warnings; nothing by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA-specific defaults on ``vllm_config`` in place.

        - Selects the GPU worker when ``worker_cls`` is ``"auto"``.
        - Defaults the KV-cache block size to 16 and, for MLA models, forces
          the block size the chosen MLA backend requires.
        - Disables CUDA graphs for the DeepEP high-throughput all2all backend
          when data parallelism is used.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        # Note: block_size is initialized in
        # HybridAttentionMambaModelConfig.verify_and_update_config
        # for models with both attention and mamba,
        # and doesn't need to be reinitialized here
        # cache_config is checked for None here as well, matching the guard
        # above; previously a missing cache_config raised AttributeError.
        if (
            model_config is not None
            and model_config.use_mla
            and cache_config is not None
            and cache_config.block_size is not None
        ):
            use_sparse = hasattr(model_config.hf_config, "index_topk")
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
            # required block_size.
            use_flashmla = False
            use_cutlass_mla = False
            use_flashinfer_mla = False

            if envs.VLLM_ATTENTION_BACKEND is None:
                # Default case
                if cls.is_device_capability(100):
                    # Blackwell => Force CutlassMLA.
                    use_cutlass_mla = True
                    # TODO: This does not work, because the
                    # global_force_attn_backend_context_manager is not set.
                    # See vllm/attention/selector.py:_cached_get_attn_backend
                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
                else:
                    # Not Blackwell
                    use_flashmla = True
            else:
                # Forced case
                use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA"
                use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
                use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"

            from vllm.attention.ops.flashmla import is_flashmla_dense_supported

            if (
                use_flashmla
                and is_flashmla_dense_supported()[0]
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size % 128 != 0:
                cache_config.block_size = 128
                logger.info(
                    "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
                )

            # FlashInferMLA accepts a block size of 32 or any multiple of 64.
            if (
                use_flashinfer_mla
                and cache_config.block_size != 32
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA backend."
                )

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse backend."
                )

        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if (
            parallel_config.all2all_backend == "deepep_high_throughput"
            and parallel_config.data_parallel_size > 1
            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # TODO: Piecewise Cuda graph might be enabled
            # if torch compile cache key issue fixed
            # See https://github.com/vllm-project/vllm/pull/25093
            logger.info(
                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
                "kernels are optimized for prefill and are incompatible with "
                "CUDA Graphs. "
                "In order to use CUDA Graphs for decode-optimized workloads, "
                "use --all2all-backend with another option, such as "
                "deepep_low_latency, pplx, or allgather_reducescatter."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated(device)`` right after
        emptying the allocator cache and resetting the peak statistics.

        Side effect: the peak-memory statistics of ``device`` are reset.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
        """Choose the attention backend used for ViT attention."""
        from vllm.attention.backends.registry import AttentionBackendEnum

        # For Blackwell GPUs, force TORCH_SDPA for now.
        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
        if cls.has_device_capability(100):
            return AttentionBackendEnum.TORCH_SDPA

        if dtype not in (torch.float16, torch.bfloat16):
            return AttentionBackendEnum.XFORMERS

        if cls.has_device_capability(80):
            backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
            if backend_class.supports_head_size(
                head_size
            ) and backend_class.supports_dtype(dtype):
                return AttentionBackendEnum.FLASH_ATTN

        # Fallback for Volta/Turing GPUs or FA not supported
        return AttentionBackendEnum.XFORMERS

    @classmethod
    def get_valid_backends(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> tuple[
        list[tuple["AttentionBackendEnum", int]],
        dict["AttentionBackendEnum", list[str]],
    ]:
        """Validate every candidate backend for this configuration.

        Returns:
            A pair ``(valid, invalid)``. ``valid`` holds ``(backend, priority)``
            tuples, where a lower priority value means more preferred.
            ``invalid`` maps each rejected backend to its list of reasons
            (``["ImportError"]`` when the backend class cannot be imported).
        """
        valid_backends_priorities = []
        invalid_reasons = {}

        backend_priorities = _get_backend_priorities(use_mla, device_capability)
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    head_size,
                    dtype,
                    kv_cache_dtype,
                    block_size,
                    use_mla,
                    has_sink,
                    use_sparse,
                    device_capability,
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = invalid_reasons_i
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_v1: bool,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
    ) -> str:
        """Resolve the attention backend and return its import path.

        A non-None ``selected_backend`` must pass validation. Otherwise the
        most-preferred valid backend from ``get_valid_backends`` is chosen.

        Raises:
            RuntimeError: if ``use_v1`` is false.
            ValueError: if the selected backend is invalid, or if no backend
                is valid for this configuration.
        """
        if not use_v1:
            raise RuntimeError(
                "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
                "to select a supported backend."
            )

        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # NOTE(review): validation below is called with block_size=None rather
        # than `block_size` (which is only used in messages); presumably
        # because the block size can still be adjusted later — confirm.

        # First try checking just the selected backend, if there is one.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    head_size,
                    dtype,
                    kv_cache_dtype,
                    None,
                    use_mla,
                    has_sink,
                    use_sparse,
                    device_capability,
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was selected, so search all candidates for a valid one.
        valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
            head_size,
            dtype,
            kv_cache_dtype,
            None,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, reasons in invalid_reasons.items()
            )
            + "}"
        )
        config_str = (
            f"head_size: {head_size}, dtype: {dtype}, "
            f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
            f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
        )
        logger.debug_once(
            f"Some attention backends are not valid for {cls.device_name} with "
            f"{config_str}. Reasons: {reasons_str}."
        )
        if not valid_backends_priorities:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # We have found some valid backends. Select the one with the
        # highest priority (lowest priority value). min() keeps the first
        # entry on ties, exactly like taking the head of a stable sort.
        logger.info(
            "Valid backends: %s", [b[0].name for b in valid_backends_priorities]
        )
        selected_backend, _ = min(valid_backends_priorities, key=lambda bp: bp[1])
        logger.info(
            "Using %s backend.",
            selected_backend.name,
        )

        return selected_backend.get_path()

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 is reported as supported from compute capability 8.9 upward."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ``ValueError`` if ``dtype`` is bfloat16 on a GPU below 8.0."""
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        if capability is None:
            compute_str = "does not have a compute capability"
        else:
            version_str = capability.as_version_str()
            compute_str = f"has compute capability {version_str}"

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU.

        Blocks are selected along dimension 1 of both caches, and the
        gathered data is moved to ``dst_cache.device`` before assignment.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU), indexing along dimension 1."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True

514

515
516
517
518
519
# NVML utils
# NVML ignores `CUDA_VISIBLE_DEVICES`: every pynvml call here works on real,
# physical device indices, so logical ids are translated first.
# Going through NVML has the big advantage of not initializing CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered by NVML."""

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of logical ``device_id``.

        Returns ``None`` when a ``RuntimeError`` is raised while resolving
        the device. Results are memoized per ``(cls, device_id)``.
        """
        try:
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
                cls.device_id_to_physical_device_id(device_id)
            )
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(nvml_handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Capability check that answers ``False`` instead of raising
        ``RuntimeError``."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name of logical ``device_id``."""
        return cls._get_physical_device_name(
            cls.device_id_to_physical_device_id(device_id)
        )

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of logical ``device_id``."""
        physical_id = cls.device_id_to_physical_device_id(device_id)
        return pynvml.nvmlDeviceGetUUID(pynvml.nvmlDeviceGetHandleByIndex(physical_id))

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of logical ``device_id`` as reported by NVML."""
        physical_id = cls.device_id_to_physical_device_id(device_id)
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(
            pynvml.nvmlDeviceGetHandleByIndex(physical_id)
        )
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Check every unordered pair once: each handle against the later ones.
        for idx, handle in enumerate(handles):
            for peer_handle in handles[idx + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of physical ``device_id``.

        Not wrapped in an NVML session itself; callers here already hold one.
        """
        return pynvml.nvmlDeviceGetName(pynvml.nvmlDeviceGetHandleByIndex(device_id))

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn about mixed GPU models when CUDA_DEVICE_ORDER is not PCI_BUS_ID."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return
        device_names = [
            cls._get_physical_device_name(i) for i in range(device_count)
        ]
        has_mixed_models = len(set(device_names)) > 1
        if has_mixed_models and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``.

    Used when NVML cannot be initialized; NVLink topology cannot be
    detected on this path.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda`` (memoized)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always ``False``: NVLink cannot be queried without NVML."""
        # This is not inside an exception handler, so `logger.exception`
        # would attach a meaningless "NoneType: None" traceback. The message
        # describes a fallback assumption, so log it as a warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Pick the NVML-backed platform when an NVML session can be opened,
# and fall back to the torch-based platform otherwise.
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # The probe succeeded; close the session it opened straight away.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()