cuda.py 21.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume a CUDA platform, e.g. it may import
pynvml. However, it should not initialize a CUDA context.
"""

7
import os
8
from collections.abc import Callable
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, TypeVar
11

12
import torch
13
from typing_extensions import ParamSpec
14

15
16
# import custom ops, trigger op registration
import vllm._C  # noqa
17
18
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
19
from vllm.logger import init_logger
20
from vllm.utils.import_utils import import_pynvml
21
from vllm.utils.torch_utils import cuda_device_count_stateless
22

23
from .interface import DeviceCapability, Platform, PlatformEnum
24

25
if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheDType
else:
    # Runtime placeholders: within this module both names are only referenced
    # from quoted annotations, so nothing is imported outside type checking.
    VllmConfig = CacheDType = None

logger = init_logger(__name__)

# Type variables that let `with_nvml_context` keep the wrapped callable's
# parameter and return types.
_P = ParamSpec("_P")
_R = TypeVar("_R")

pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

43

44
45
46
47
48
49
50
51
52
53
54
55
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
    """Return the candidate attention backends, most preferred first.

    The ordering depends on whether MLA is in use and on whether the device
    reports compute capability major version 10.

    The result is memoized by ``@cache``, so repeated calls with the same
    arguments hand back the same list object; callers should treat it as
    read-only.
    """
    is_major_10 = device_capability.major == 10

    if use_mla and is_major_10:
        return [
            AttentionBackendEnum.CUTLASS_MLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    if use_mla:
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    if is_major_10:
        return [
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]

    return [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.FLASHINFER,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.FLEX_ATTENTION,
    ]


85
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that it runs between ``nvmlInit``/``nvmlShutdown``.

    NVML is initialized before every call and shut down afterwards, even when
    ``fn`` raises; the exception (if any) still propagates to the caller.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Always pair the init above with a shutdown.
            pynvml.nvmlShutdown()
        return result

    return wrapper


97
98
class CudaPlatformBase(Platform):
    """Behaviour shared by the CUDA platform implementations in this module.

    Device queries (capability, name, total memory, interconnect topology)
    raise ``NotImplementedError`` here; the concrete subclasses in this
    module provide them.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the dtypes usable on the device, by compute capability."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Report whether the given devices are fully connected (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform-specific warnings; the base class has none."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA defaults on ``vllm_config`` (mutated in place).

        * ``parallel_config.worker_cls`` of ``"auto"`` becomes the V1 GPU
          worker.
        * An unset ``cache_config.block_size`` becomes 16.
        * For MLA models, the KV-cache block size is forced to the value
          required by the MLA backend in use. When no backend was requested
          on a device with capability 100, CUTLASS_MLA is also written to
          ``attention_config.backend``.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        # Note: block_size is initialized in
        # HybridAttentionMambaModelConfig.verify_and_update_config
        # for models with both attention and mamba,
        # and doesn't need to be reinitialized here
        if (
            model_config is None
            or not model_config.use_mla
            # cache_config is treated as optional above, so it must be
            # checked before its block_size is read.
            or cache_config is None
            or cache_config.block_size is None
        ):
            return

        use_sparse = hasattr(model_config.hf_config, "index_topk")
        # If `--attention-config.backend` is not set and we are using MLA,
        # then we default to FlashMLA backend for non-blackwell GPUs,
        # else we default to CutlassMLA. For each case, we force the
        # required block_size.
        use_flashmla = False
        use_cutlass_mla = False
        use_flashinfer_mla = False

        backend = vllm_config.attention_config.backend
        if backend is None:
            # Default case
            if cls.is_device_capability(100):
                # Blackwell => Force CutlassMLA.
                use_cutlass_mla = True
                # Set the backend in AttentionConfig so it's used during
                # backend selection
                vllm_config.attention_config.backend = (
                    AttentionBackendEnum.CUTLASS_MLA
                )
            else:
                # Not Blackwell
                use_flashmla = True
        else:
            # Forced case
            use_flashmla = backend == AttentionBackendEnum.FLASHMLA
            use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
            use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA

        from vllm.attention.ops.flashmla import is_flashmla_dense_supported

        if (
            use_flashmla
            and is_flashmla_dense_supported()[0]
            and cache_config.block_size % 64 != 0
        ):
            cache_config.block_size = 64
            logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

        if use_cutlass_mla and cache_config.block_size % 128 != 0:
            cache_config.block_size = 128
            logger.info(
                "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
            )

        if (
            use_flashinfer_mla
            and cache_config.block_size != 32
            and cache_config.block_size % 64 != 0
        ):
            cache_config.block_size = 64
            logger.info(
                "Forcing kv cache block size to 64 for FlashInferMLA backend."
            )

        # TODO(Chen): remove this hacky code
        if use_sparse and cache_config.block_size != 64:
            cache_config.block_size = 64
            logger.info(
                "Forcing kv cache block size to 64 for FlashMLASparse backend."
            )

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return ``torch.cuda.max_memory_allocated`` after a stats reset.

        The caching allocator is emptied and the peak statistics are reset
        before the value is read.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
        """Choose the attention backend for vision transformers.

        FLASH_ATTN is returned when the device has compute capability major
        version >= 8, the backend can be imported, and it supports both
        ``head_size`` and ``dtype``; otherwise TORCH_SDPA is returned.
        """
        # Try FlashAttention first
        if (cc := cls.get_device_capability()) and cc.major >= 8:
            try:
                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
                if backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype):
                    return AttentionBackendEnum.FLASH_ATTN
            except ImportError:
                # FlashAttention is not importable; fall back to SDPA below.
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_valid_backends(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
        attn_type: str,
    ) -> tuple[
        list[tuple["AttentionBackendEnum", int]],
        dict["AttentionBackendEnum", list[str]],
    ]:
        """Validate every prioritized backend against the given configuration.

        Returns:
            A pair ``(valid, invalid)``. ``valid`` lists ``(backend, priority)``
            for each backend whose ``validate_configuration`` reported no
            problems, where a lower priority number means more preferred.
            ``invalid`` maps each rejected backend to its list of reasons;
            a backend that cannot be imported is recorded as
            ``["ImportError"]``.
        """
        valid_backends_priorities: list[tuple[AttentionBackendEnum, int]] = []
        invalid_reasons: dict[AttentionBackendEnum, list[str]] = {}

        backend_priorities = _get_backend_priorities(use_mla, device_capability)
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    head_size,
                    dtype,
                    kv_cache_dtype,
                    block_size,
                    use_mla,
                    has_sink,
                    use_sparse,
                    device_capability,
                    attn_type,
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = invalid_reasons_i
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        attn_type: str | None = None,
    ) -> str:
        """Return the import path of the attention backend to use.

        When ``selected_backend`` is given, only that backend is validated
        and its path is returned. Otherwise the valid backend with the
        lowest priority number is chosen.

        Raises:
            ValueError: if ``selected_backend`` is not valid for this
                configuration, or if no backend at all is valid.
        """
        if attn_type is None:
            attn_type = AttentionType.DECODER

        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # First try checking just the selected backend, if there is one.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                # NOTE(review): None is passed in the block_size position
                # rather than `block_size`; confirm this is intentional.
                invalid_reasons = backend_class.validate_configuration(
                    head_size,
                    dtype,
                    kv_cache_dtype,
                    None,
                    use_mla,
                    has_sink,
                    use_sparse,
                    device_capability,
                    attn_type,
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was explicitly selected (an invalid explicit selection
        # has already raised above), so we try finding a valid backend.
        valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
            head_size,
            dtype,
            kv_cache_dtype,
            None,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
            attn_type,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, reasons in invalid_reasons.items()
            )
            + "}"
        )
        config_str = (
            f"head_size: {head_size}, dtype: {dtype}, "
            f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
            f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
        )
        logger.debug_once(
            f"Some attention backends are not valid for {cls.device_name} with "
            f"{config_str}. Reasons: {reasons_str}."
        )
        if len(valid_backends_priorities) == 0:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # We have found some valid backends. Select the one with the
        # highest priority, i.e. the smallest priority number. `min` returns
        # the first such entry, matching a stable sort.
        selected_backend, _ = min(
            valid_backends_priorities, key=lambda entry: entry[1]
        )
        logger.info(
            "Using %s attention backend out of potential backends: %s",
            selected_backend.name,
            [b[0].name for b in valid_backends_priorities],
        )

        return selected_backend.get_path()

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the Punica (LoRA) wrapper for GPUs."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the CUDA device communicator."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether the device has compute capability of at least 8.9."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Custom all-reduce is enabled on CUDA."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Attention is registered as an opaque op on CUDA."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the import path of the CUDA graph wrapper."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def device_count(cls) -> int:
        """Return the CUDA device count from the stateless helper."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ``ValueError`` if ``dtype`` is bfloat16 on a pre-8.0 GPU.

        Any other dtype, or bfloat16 on a device with compute capability of
        at least 8.0, is accepted silently.
        """
        if dtype != torch.bfloat16 or cls.has_device_capability(80):
            return

        capability = cls.get_device_capability()
        gpu_name = cls.get_device_name()

        if capability is None:
            compute_str = "does not have a compute capability"
        else:
            version_str = capability.as_version_str()
            compute_str = f"has compute capability {version_str}"

        raise ValueError(
            "Bfloat16 is only supported on GPUs "
            "with compute capability of at least 8.0. "
            f"Your {gpu_name} GPU {compute_str}. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half."
        )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU."""
        # Blocks are selected along dimension 1 of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU)."""
        # Blocks are selected along dimension 1 of both caches.
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is supported on CUDA."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Static graph mode is supported on CUDA."""
        return True

484

485
486
487
488
489
# NVML utils
# NVML is not affected by `CUDA_VISIBLE_DEVICES`: every NVML call below works
# on real physical device ids. The major benefit of using NVML is that it
# will not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through NVML."""

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``.

        ``None`` is returned when a ``RuntimeError`` is raised during the
        lookup. Results are memoized per ``(cls, device_id)``.
        """
        try:
            nvml_index = cls.device_id_to_physical_device_id(device_id)
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_index)
            cc_major, cc_minor = pynvml.nvmlDeviceGetCudaComputeCapability(
                nvml_handle
            )
            return DeviceCapability(major=cc_major, minor=cc_minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Delegate to the parent check, mapping ``RuntimeError`` to ``False``."""
        try:
            meets_requirement = super().has_device_capability(capability, device_id)
        except RuntimeError:
            meets_requirement = False
        return meets_requirement

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for the logical ``device_id``."""
        return cls._get_physical_device_name(
            cls.device_id_to_physical_device_id(device_id)
        )

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID for the logical ``device_id``."""
        nvml_index = cls.device_id_to_physical_device_id(device_id)
        return pynvml.nvmlDeviceGetUUID(pynvml.nvmlDeviceGetHandleByIndex(nvml_index))

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field of NVML's memory info, as an ``int``."""
        nvml_index = cls.device_id_to_physical_device_id(device_id)
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(
            pynvml.nvmlDeviceGetHandleByIndex(nvml_index)
        )
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        nvml_handles = [
            pynvml.nvmlDeviceGetHandleByIndex(idx) for idx in physical_device_ids
        ]
        # Visit each unordered pair once: pair every handle only with the
        # handles that come after it.
        for position, first in enumerate(nvml_handles):
            for second in nvml_handles[position + 1 :]:
                try:
                    link_status = pynvml.nvmlDeviceGetP2PStatus(
                        first,
                        second,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if link_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of the device at physical index ``device_id``.

        Not wrapped in ``with_nvml_context``; the callers in this class
        invoke it from methods that are.
        """
        return pynvml.nvmlDeviceGetName(pynvml.nvmlDeviceGetHandleByIndex(device_id))

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when differing GPUs are present without PCI bus ordering."""
        gpu_count: int = pynvml.nvmlDeviceGetCount()
        if gpu_count <= 1:
            return

        gpu_names = [cls._get_physical_device_name(idx) for idx in range(gpu_count)]
        has_mixed_gpus = len(set(gpu_names)) > 1
        if has_mixed_gpus and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(gpu_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``.

    Selected by this module when NVML cannot be initialized.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda``.

        Results are memoized per ``(cls, device_id)``.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always report ``False``: NVLink cannot be detected on this platform."""
        # `logger.exception` is only meaningful inside an `except` block;
        # called here it would attach an empty "NoneType: None" traceback.
        # No exception is active, so a plain warning is emitted instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available: probe it once by initializing it.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # The probe succeeded; release NVML again.
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()