# cuda.py (22.7 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

7
import os
8
from collections.abc import Callable
9
from datetime import timedelta
10
from functools import cache, wraps
11
from typing import TYPE_CHECKING, TypeVar
12

13
import torch
14
15
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
16
from typing_extensions import ParamSpec
17

18
19
# import custom ops, trigger op registration
import vllm._C  # noqa
20
from vllm.logger import init_logger
21
from vllm.utils.import_utils import import_pynvml
22
from vllm.utils.torch_utils import cuda_device_count_stateless
23
from vllm.v1.attention.backends.registry import AttentionBackendEnum
24

25
from .interface import DeviceCapability, Platform, PlatformEnum
26

27
if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheDType
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # Only needed for annotations; bind placeholders at runtime.
    VllmConfig = None
    CacheDType = None

logger = init_logger(__name__)

# Signature-preserving type variables for the `with_nvml_context` decorator.
_P = ParamSpec("_P")
_R = TypeVar("_R")

pynvml = import_pynvml()

# PyTorch 2.5 selects the cuDNN SDPA kernel by default, which crashes on some
# models (details: https://github.com/huggingface/diffusers/issues/9704).
# Disable it globally when this module is imported.
torch.backends.cuda.enable_cudnn_sdp(False)

46

47
48
49
50
@cache
def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
    num_heads: int | None = None,
) -> list[AttentionBackendEnum]:
    """Return candidate attention backends, most preferred first.

    Args:
        use_mla: Choose from the MLA backend family instead of the regular
            attention backends.
        device_capability: Compute capability of the target GPU. Devices whose
            major version is 10 get their own ordering.
        num_heads: Number of attention heads, if known. It is only consulted
            for MLA on major-version-10 devices, where it decides the order of
            the two sparse MLA backends.

    Returns:
        The ordered backend list. The caller uses each entry's index as its
        priority (lower index = preferred). The result is memoized with
        ``@cache`` and the same list object is returned on every call with
        the same arguments, so callers must not mutate it.
    """
    if use_mla:
        if device_capability.major == 10:
            # Prefer FlashInfer at low head counts (FlashMLA uses padding).
            if num_heads is not None and num_heads <= 16:
                sparse_backends = [
                    AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
                    AttentionBackendEnum.FLASHMLA_SPARSE,
                ]
            else:
                sparse_backends = [
                    AttentionBackendEnum.FLASHMLA_SPARSE,
                    AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
                ]
            return [
                AttentionBackendEnum.FLASHINFER_MLA,
                AttentionBackendEnum.CUTLASS_MLA,
                AttentionBackendEnum.FLASH_ATTN_MLA,
                AttentionBackendEnum.FLASHMLA,
                AttentionBackendEnum.TRITON_MLA,
                *sparse_backends,
            ]
        return [
            AttentionBackendEnum.FLASH_ATTN_MLA,
            AttentionBackendEnum.FLASHMLA,
            AttentionBackendEnum.FLASHINFER_MLA,
            AttentionBackendEnum.TRITON_MLA,
            AttentionBackendEnum.FLASHMLA_SPARSE,
        ]

    if device_capability.major == 10:
        return [
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLEX_ATTENTION,
        ]
    return [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.FLASHINFER,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.FLEX_ATTENTION,
    ]


100
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate *fn* so that every call runs inside its own NVML session.

    ``pynvml.nvmlInit()`` is called before *fn* runs. ``pynvml.nvmlShutdown()``
    is called afterwards in a ``finally`` clause, so it runs even when *fn*
    raises. If ``nvmlInit`` itself raises, *fn* is not called and no shutdown
    is attempted.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


112
113
class CudaPlatformBase(Platform):
    """CUDA platform behavior shared by the NVML and non-NVML variants.

    Device queries (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory``, ``is_fully_connected``) raise
    ``NotImplementedError`` here and are provided by subclasses.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    ray_noset_device_env_vars: list[str] = [
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    ]

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Floating-point dtypes usable on the current device.

        BF16 needs capability 8.0+, FP16 needs 6.0+, and anything older only
        gets FP32.
        """
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the marketing name of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether all ``device_ids`` are NVLink-connected (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform warnings at import time; nothing to report by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply CUDA-specific defaults to ``vllm_config`` in place.

        Selects the GPU worker when ``worker_cls`` is ``"auto"``. Also forces
        ``disable_chunked_mm_input`` for multimodal prefix-LM models.
        """
        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the memory currently allocated by PyTorch on ``device``.

        The peak statistics are reset first. After that,
        ``max_memory_allocated`` reflects the current allocation rather than
        a historical peak.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_valid_backends(
        cls,
        device_capability: DeviceCapability,
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> tuple[
        list[tuple["AttentionBackendEnum", int]],
        dict["AttentionBackendEnum", tuple[int, list[str]]],
    ]:
        """Validate every candidate backend against the configuration.

        Returns:
            A pair ``(valid, invalid)``:

            - ``valid`` lists ``(backend, priority)`` for backends that
              accepted the configuration.
            - ``invalid`` maps each rejected backend to
              ``(priority, reasons)``.

            A backend whose class fails to import is rejected with the reason
            ``"ImportError"``.
        """
        valid_backends_priorities = []
        invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}

        backend_priorities = _get_backend_priorities(
            attn_selector_config.use_mla,
            device_capability,
            num_heads,
        )
        for priority, backend in enumerate(backend_priorities):
            try:
                backend_class = backend.get_class()
                invalid_reasons_i = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons_i = ["ImportError"]
            if invalid_reasons_i:
                invalid_reasons[backend] = (priority, invalid_reasons_i)
            else:
                valid_backends_priorities.append((backend, priority))

        return valid_backends_priorities, invalid_reasons

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum | None",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend to use.

        An explicitly ``selected_backend`` is validated and used as-is. It is
        never silently replaced by a different backend. Without a selection,
        the valid backend with the lowest priority value from
        :meth:`get_valid_backends` is chosen.

        Raises:
            ValueError: If the selected backend is invalid for this
                configuration, or if no backend is valid at all.
        """
        device_capability = cls.get_device_capability()
        assert device_capability is not None

        # An explicit selection is validated on its own; failure is an error.
        if selected_backend is not None:
            try:
                backend_class = selected_backend.get_class()
                invalid_reasons = backend_class.validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            except ImportError:
                invalid_reasons = ["ImportError"]
            if invalid_reasons:
                raise ValueError(
                    f"Selected backend {selected_backend} is not valid for "
                    f"this configuration. Reason: {invalid_reasons}"
                )
            logger.info("Using %s backend.", selected_backend)
            return selected_backend.get_path()

        # No backend was selected, so search the priority list for valid ones.
        valid_backends_priorities, all_invalid_reasons = cls.get_valid_backends(
            device_capability=device_capability,
            attn_selector_config=attn_selector_config,
            num_heads=num_heads,
        )
        reasons_str = (
            "{"
            + ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, (_, reasons) in all_invalid_reasons.items()
            )
            + "}"
        )
        config_str = attn_selector_config.__repr__()
        # Only report rejections when there actually were some.
        if all_invalid_reasons:
            logger.debug_once(
                f"Some attention backends are not valid for {cls.device_name} with "
                f"{config_str}. Reasons: {reasons_str}."
            )
        if not valid_backends_priorities:
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )

        # The lowest priority value wins; min() keeps the first entry on ties.
        selected_backend, selected_priority = min(
            valid_backends_priorities, key=lambda item: item[1]
        )

        # If the user specified --block-size (but not --attention-backend),
        # check whether that constraint precluded any higher-priority backends.
        if attn_selector_config.block_size is not None:
            excluded = [
                backend
                for backend, (priority, reasons) in all_invalid_reasons.items()
                if priority < selected_priority
                and reasons == ["block_size not supported"]
            ]
            if excluded:
                names = ", ".join(b.name for b in excluded)
                logger.warning(
                    "--block-size %d precluded higher-priority backend(s) "
                    "%s. Using %s instead, which may result in reduced "
                    "performance. Consider removing --block-size to "
                    "auto-select the optimal block size.",
                    attn_selector_config.block_size,
                    names,
                    selected_backend.name,
                )

        logger.info_once(
            "Using %s attention backend out of potential backends: %s.",
            selected_backend.name,
            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
            scope="local",
        )

        return selected_backend.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the ViT attention backends, in the order they are tried."""
        if cls.has_device_capability(80):
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.FLASHINFER,
            ]
        else:
            return [
                AttentionBackendEnum.FLASH_ATTN,
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.TRITON_ATTN,
                AttentionBackendEnum.FLASHINFER,
            ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Pick the attention backend for vision-transformer layers.

        An explicit ``backend`` must appear in
        :meth:`get_supported_vit_attn_backends` (checked with ``assert``).
        Otherwise the supported list is walked in order:

        - ``TORCH_SDPA`` is returned as soon as it is reached.
        - Any other backend is returned if it supports ``head_size``,
          ``dtype`` and (when known) the device capability.
        - Backends that fail to import are skipped.

        ``TORCH_SDPA`` is the final fallback.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        cc = cls.get_device_capability()
        for vit_attn_backend in cls.get_supported_vit_attn_backends():
            if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
                return vit_attn_backend
            try:
                backend_class = vit_attn_backend.get_class()
                is_backend_supported = backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype)
                if cc is not None:
                    is_backend_supported = (
                        is_backend_supported
                        and backend_class.supports_compute_capability(cc)
                    )
                if is_backend_supported:
                    logger.info_once(
                        f"Using backend {vit_attn_backend} for vit attention"
                    )
                    return vit_attn_backend
            except ImportError:
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Import path of the LoRA Punica wrapper for GPUs."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Import path of the CUDA device communicator."""
        return (
            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
        )

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 is reported as supported from compute capability 8.9 upward."""
        return cls.has_device_capability(89)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Import path of the CUDA-graph wrapper class."""
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ``ProcessGroup`` on ``prefix_store`` backed by NCCL.

        The ``backend`` argument is not consulted; NCCL is always used.
        ``timeout`` is applied to the NCCL backend options.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Number of visible CUDA devices (via ``cuda_device_count_stateless``)."""
        return cuda_device_count_stateless()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """Raise ``ValueError`` if ``dtype`` is BF16 on a pre-8.0 device."""
        if dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half."
                )

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on GPU.

        Blocks are indexed along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from GPU to host (CPU).

        Blocks are indexed along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True

    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """Multiprocessor count of ``device_id`` as reported by PyTorch."""
        return torch.cuda.get_device_properties(device_id).multi_processor_count

    @classmethod
    def use_custom_op_collectives(cls) -> bool:
        return True

504

505
506
507
508
509
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    Logical device ids are translated with
    ``device_id_to_physical_device_id`` before each NVML lookup, because NVML
    indexes physical devices.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``, or ``None``.

        ``None`` is returned (and cached) when the lookup raises
        ``RuntimeError``.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Base-class capability check run inside an NVML session.

        Returns ``False`` instead of propagating a ``RuntimeError``.
        """
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name of logical device ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of logical device ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``nvmlDeviceGetMemoryInfo(...).total`` for ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of devices is checked once. The first pair whose
        NVLink P2P status is not OK makes the result ``False``, as does any
        NVML error.
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Only look at later devices so each pair is queried exactly once.
            for peer_handle in handles[i + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of *physical* device ``device_id``.

        Opens no NVML session itself. Both callers in this class are
        decorated with ``with_nvml_context``.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn about mixed GPU models when ``CUDA_DEVICE_ORDER`` is not set.

        The warning fires only if the models differ and ``CUDA_DEVICE_ORDER``
        is not ``PCI_BUS_ID``.
        """
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(num_devices)
            ]
            if (
                len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
            ):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform used when NVML is unavailable.

    Device queries go through ``torch.cuda``. Unlike NVML, these calls may
    create a CUDA context.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return (and cache) the compute capability of ``device_id``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``torch.cuda.get_device_name(device_id)``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from ``torch.cuda.get_device_properties``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Always ``False``: NVLink topology cannot be queried without NVML."""
        # There is no active exception here, so logger.exception() would log
        # at ERROR level with a spurious "NoneType: None" traceback. Use warning.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Choose the platform implementation at import time. If NVML can be
# initialized, use the NVML-backed platform; otherwise fall back to the
# torch.cuda-based one.
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    pass
else:
    # NVML works; remember that and release the probe session.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()