cuda.py 22.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from collections.abc import Callable
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, TypeVar
11

12
import torch
13
from typing_extensions import ParamSpec
14

15
16
# import custom ops, trigger op registration
import vllm._C  # noqa
17
from vllm.logger import init_logger
18
from vllm.utils.import_utils import import_pynvml
19
from vllm.utils.torch_utils import cuda_device_count_stateless
20
from vllm.v1.attention.backends.registry import AttentionBackendEnum
21

22
from .interface import DeviceCapability, Platform, PlatformEnum
23

24
if TYPE_CHECKING:
25
    from vllm.config import VllmConfig
26
    from vllm.config.cache import CacheDType
27
    from vllm.v1.attention.selector import AttentionSelectorConfig
28
else:
29
30
    VllmConfig = None
    CacheDType = None
31

32
33
# Module-level logger shared by every platform class in this file.
logger = init_logger(__name__)

# Typing helpers that let `with_nvml_context` preserve the signature and
# return type of the callable it wraps.
_R = TypeVar("_R")
_P = ParamSpec("_P")

# NVML bindings, obtained through the vLLM import helper.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

43

44
45
46
47
48
49
50
51
52
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
    """Return attention backends ordered from most to least preferred.

    The ordering depends on whether MLA is in use and on whether the device
    reports compute-capability major version 10.

    The result is memoized per ``(use_mla, device_capability)`` pair, so the
    same list object is handed to every caller; treat it as read-only.
    """
    is_major_10 = device_capability.major == 10

    if use_mla and is_major_10:
        return [
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.CUTLASS_MLA,
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    if use_mla:
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    if is_major_10:
        return [
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]

    return [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.FLASHINFER,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.FLEX_ATTENTION,
    ]


85
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that it runs between ``nvmlInit`` and ``nvmlShutdown``.

    ``nvmlShutdown`` is called even when ``fn`` raises.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
        return result

    return _nvml_scoped


97
98
class CudaPlatformBase(Platform):
    """CUDA platform behaviour that is independent of how devices are queried.

    The device-query classmethods (``get_device_capability``,
    ``get_device_name``, ``get_device_total_memory`` and
    ``is_fully_connected``) raise ``NotImplementedError`` here; concrete
    subclasses are expected to override them.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    ]

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the torch dtypes usable at this device's compute capability."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether the given devices are fully connected (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform warnings; the base implementation does nothing."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA-specific defaults on ``vllm_config`` in place.

        * Resolves ``parallel_config.worker_cls == "auto"`` to the GPU worker.
        * Defaults ``cache_config.block_size`` to 16 when it is unset.
        * For MLA models, picks a default MLA backend when none is configured
          and forces the KV-cache block size that backend requires.
        * Forces ``disable_chunked_mm_input`` for multimodal prefix-LM models.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        # Note: block_size is initialized in
        # HybridAttentionMambaModelConfig.verify_and_update_config
        # for models with both attention and mamba,
        # and doesn't need to be reinitialized here
        if (
            model_config is not None
            and model_config.use_mla
            # The block_size default above tolerates a missing cache_config,
            # so this branch must tolerate it as well.
            and cache_config is not None
            and cache_config.block_size is not None
        ):
            use_sparse = hasattr(model_config.hf_config, "index_topk")
            # If `--attention-config.backend` is not set and we are using MLA,
            # the default depends on the device: on the capability-100 family
            # FlashInfer MLA is chosen (falling back to CUTLASS MLA when the
            # head-dim constraint is not met); elsewhere FlashMLA is chosen
            # when it is supported. For each case, we force the required
            # block_size.
            use_flashmla = False
            use_cutlass_mla = False
            use_flashinfer_mla = False

            from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported

            if vllm_config.attention_config.backend is None:
                # Default case
                hf_text_config = model_config.hf_text_config
                qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
                is_blackwell = cls.is_device_capability_family(100)
                if is_blackwell and not use_sparse and qk_nope_head_dim == 128:
                    # Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
                    # and only if qk_nope_head_dim == 128 (kernel constraint)
                    use_flashinfer_mla = True
                    # Set the backend in AttentionConfig so it's used during
                    # backend selection
                    vllm_config.attention_config.backend = (
                        AttentionBackendEnum.FLASHINFER_MLA
                    )
                elif is_blackwell and not use_sparse:
                    # Fall back to CUTLASS_MLA as 2nd priority on Blackwell
                    use_cutlass_mla = True
                elif is_flashmla_dense_supported()[0]:
                    # Non-Blackwell with FlashMLA support
                    use_flashmla = True
                # Otherwise: fall through; Triton MLA or another compatible
                # backend will be used and no block size is forced here.
            else:
                # Forced case
                backend = vllm_config.attention_config.backend
                use_flashmla = backend == AttentionBackendEnum.FLASHMLA
                use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
                use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA

            if (
                use_flashmla
                and is_flashmla_dense_supported()[0]
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size % 128 != 0:
                cache_config.block_size = 128
                logger.info(
                    "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
                )

            if (
                use_flashinfer_mla
                and cache_config.block_size != 32
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA backend."
                )

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse backend."
                )

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated(device)`` after a reset.

        The CUDA allocator cache is emptied and the peak-memory statistics
        are reset first.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: "AttentionSelectorConfig",
    ) -> tuple[
        list[tuple["AttentionBackendEnum", int]],
        dict["AttentionBackendEnum", list[str]],
    ]:
        """Split the prioritized backends into valid and invalid ones.

        Returns:
            A pair ``(valid, invalid)``: ``valid`` is a list of
            ``(backend, priority)`` tuples where a lower priority number is
            preferred, and ``invalid`` maps each rejected backend to the
            reasons reported for it. A backend whose class cannot be imported
            is rejected with the single reason ``"ImportError"``.
        """
        valid_backends_priorities = []
        invalid_reasons = {}

        backend_priorities = _get_backend_priorities(
            attn_selector_config.use_mla, device_capability
        )
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = invalid_reasons_i
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the import path of the attention backend to use.

        When ``selected_backend`` is given it is validated and returned.
        When it is ``None`` the highest-priority valid backend is chosen.

        Raises:
            ValueError: if ``selected_backend`` is not valid for this
                configuration, or if no backend is valid at all.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # NOTE(review): block_size is cleared before validation, presumably so
        # that backends are not rejected on the current block size - confirm.
        attn_selector_config = attn_selector_config._replace(block_size=None)

        # First try checking just the selected backend, if there is one.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was selected explicitly, so search for a valid one.
        valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, reasons in invalid_reasons.items()
            )
            + "}"
        )
        config_str = repr(attn_selector_config)
        logger.debug_once(
            f"Some attention backends are not valid for {cls.device_name} with "
            f"{config_str}. Reasons: {reasons_str}."
        )
        if len(valid_backends_priorities) == 0:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # We have found some valid backends. Select the one with the
        # highest priority, i.e. the smallest priority number.
        selected_backend, _ = min(
            valid_backends_priorities, key=lambda entry: entry[1]
        )
        logger.info_once(
            "Using %s attention backend out of potential backends: %s.",
            selected_backend.name,
            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
            scope="local",
        )

        return selected_backend.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends allowed for ViT attention."""
        return [
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.FLASH_ATTN,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Choose the ViT attention backend.

        An explicitly requested ``backend`` is returned as-is after checking
        that it is supported. Otherwise FlashAttention is used when the
        device has compute-capability major >= 8 and the backend supports
        ``head_size`` and ``dtype``; in every other case Torch SDPA is used.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        # Try FlashAttention first
        if (cc := cls.get_device_capability()) and cc.major >= 8:
            try:
                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
                if backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype):
                    return AttentionBackendEnum.FLASH_ATTN
            except ImportError:
                # FlashAttention is not importable; fall back to Torch SDPA.
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator class."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether the device has compute capability 8.9 or higher."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Always ``True`` on this platform."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Always ``True`` on this platform."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the static-graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        """Return the device count from ``cuda_device_count_stateless``."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ``ValueError`` if ``dtype`` is bfloat16 on a pre-8.0 device."""
        if dtype == torch.bfloat16 and not cls.has_device_capability(80):
            capability = cls.get_device_capability()
            gpu_name = cls.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs "
                "with compute capability of at least 8.0. "
                f"Your {gpu_name} GPU {compute_str}. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half."
            )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU."""
        # Blocks are selected along dimension 1 of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU)."""
        # Blocks are selected along dimension 1 of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Always ``True`` on this platform."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Always ``True`` on this platform."""
        return True

494

495
496
497
498
499
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through NVML."""

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``.

        Returns ``None`` when the lookup raises ``RuntimeError``. Results
        are memoized.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(nvml_handle)
        except RuntimeError:
            return None
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Parent capability check, with ``RuntimeError`` mapped to ``False``."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for logical device ``device_id``."""
        return cls._get_physical_device_name(
            cls.device_id_to_physical_device_id(device_id)
        )

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID for logical device ``device_id``."""
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            cls.device_id_to_physical_device_id(device_id)
        )
        return pynvml.nvmlDeviceGetUUID(nvml_handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field of NVML's memory info for ``device_id``."""
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            cls.device_id_to_physical_device_id(device_id)
        )
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle)
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Visit every unordered pair exactly once: each handle is checked
        # against the handles that come after it.
        for position, handle in enumerate(handles):
            for peer_handle in handles[position + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
                else:
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name for a physical device index.

        Not wrapped in ``with_nvml_context``; the callers in this class are.
        """
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(nvml_handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when differing GPUs are present without PCI_BUS_ID ordering."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return

        device_names = [
            cls._get_physical_device_name(i) for i in range(device_count)
        ]
        has_mixed_devices = len(set(device_names)) > 1
        if has_mixed_devices and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through ``torch.cuda``."""

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda`` (memoized)."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always return ``False``; no NVLink query is performed here."""
        # `logger.exception` is only meaningful inside an exception handler;
        # called here it would append a spurious "NoneType: None" traceback.
        # No exception is in flight, so log a plain warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # The probe succeeded; release NVML again straight away.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# Emit any platform warnings once, at import time.
CudaPlatform.log_warnings()