cuda.py 22.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Code inside this file can safely assume the CUDA platform, e.g. importing
pynvml. However, it should not initialize a CUDA context.
"""

7
import os
8
from datetime import timedelta
9
from functools import cache, wraps
10
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
11

12
import torch
13
14
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
15
from typing_extensions import ParamSpec
16

17
18
# import custom ops, trigger op registration
import vllm._C  # noqa
19
import vllm.envs as envs
20
from vllm.logger import init_logger
21
from vllm.utils import cuda_device_count_stateless, import_pynvml
22

23
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
24

25
if TYPE_CHECKING:
26
    from vllm.config import ModelConfig, VllmConfig
27

28
29
logger = init_logger(__name__)

# Parameter/return type variables that let `with_nvml_context` keep the
# decorated function's signature.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# NVML bindings used by the NVML-backed platform and the autodetection at the
# bottom of this module.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

39

40
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap ``fn`` so that every call runs inside its own NVML session.

    ``pynvml.nvmlInit()`` is invoked before ``fn`` runs and
    ``pynvml.nvmlShutdown()`` afterwards, whether ``fn`` returns normally or
    raises. The wrapper keeps ``fn``'s signature and metadata.
    """

    @wraps(fn)
    def _nvml_scoped_call(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            # Always release the session, even when `fn` raised.
            pynvml.nvmlShutdown()
        return outcome

    return _nvml_scoped_call


53
54
class CudaPlatformBase(Platform):
    """Shared logic for CUDA platforms.

    Subclasses must provide the device-query primitives
    (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory`` and ``is_fully_connected``). This base class
    builds dtype support, config fix-ups and attention-backend selection on
    top of them.
    """
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the floating-point dtypes usable on the current device."""
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        super().set_device(device)
        # With this trick we can force the device to be set eagerly
        # see https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id``; subclass hook."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the marketing name of ``device_id``; subclass hook."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` in bytes; subclass hook."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing needs CUDA graphs, i.e. no enforce-eager."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether all ``device_ids`` are NVLink-connected; hook."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform-specific startup warnings (none by default)."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA defaults and fix incompatible settings in place.

        - Resolves ``parallel_config.worker_cls == "auto"`` to a concrete
          worker class for the V0/V1 engine.
        - Defaults the KV-cache block size to 16 when unset, and forces it to
          64 when an MLA model will use FlashMLA.
        - Disables CUDA graphs and inductor when DeepEP high-throughput
          all-to-all kernels are combined with data parallelism.

        Raises:
            NotImplementedError: multi-step scheduling requested on V1.
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        # Note: model_config (and possibly cache_config) may be None during
        # testing, so every dereference below is guarded.
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_worker.MultiStepWorker"
            elif envs.VLLM_USE_V1:
                # V1 uses the same GPU worker with or without speculative
                # decoding.
                parallel_config.worker_cls = \
                    "vllm.v1.worker.gpu_worker.Worker"
            elif vllm_config.speculative_config:
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = \
                    "vllm.worker.worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        if model_config is not None and model_config.use_mla:
            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
            # we default to FlashMLA backend, so we need to force the blocksize
            # here
            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None
                            or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
            from vllm.attention.ops.flashmla import is_flashmla_supported
            if (use_flashmla and cache_config is not None
                    and is_flashmla_supported()[0]
                    and cache_config.block_size != 64):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLA backend.")

        if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                and parallel_config.data_parallel_size > 1
                and vllm_config.compilation_config.use_cudagraph):
            logger.info(
                "Data Parallel: Forcing enforce eager to be True since DP "
                "with DeepEP high-throughput kernels are not CUDA Graph "
                "compatible. The DeepEP low-latency kernels are CUDA Graph "
                "compatible. Set the all_to_all backend to deepep_low_latency "
                "to use those kernels instead.")
            vllm_config.compilation_config.use_cudagraph = False
            if model_config is not None:
                model_config.enforce_eager = True
            # TODO (varun): Turning this ON gives incorrect results for the
            # Deepseek-V2-lite model.
            vllm_config.compilation_config.use_inductor = False

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return peak allocated bytes after emptying the cache and resetting
        the peak-memory statistics for ``device``."""
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the import path of the attention backend class to use.

        Args:
            selected_backend: backend explicitly requested by the user, or a
                falsy value to let the platform choose.
            head_size: attention head size of the model.
            dtype: model activation dtype.
            kv_cache_dtype: KV-cache dtype string (``"fp8..."`` values mark
                an FP8 cache), or None.
            block_size: KV-cache block size.
            use_v1: whether the V1 engine is in use.
            use_mla: whether the model uses multi-head latent attention.

        Raises:
            ValueError: an unsupported backend was requested on the V0 engine.
        """
        if use_mla:
            return cls._get_mla_attn_backend_cls(selected_backend, block_size,
                                                 use_v1)
        if use_v1:
            return cls._get_v1_attn_backend_cls(selected_backend, head_size,
                                                dtype)
        return cls._get_v0_attn_backend_cls(selected_backend, head_size,
                                            dtype, kv_cache_dtype, block_size,
                                            use_v1, use_mla)

    @classmethod
    def _get_mla_attn_backend_cls(cls, selected_backend, block_size,
                                  use_v1) -> str:
        """Pick an MLA backend: Cutlass MLA (V1 only), FlashMLA or Triton MLA.

        Triton MLA is the fallback whenever FlashMLA cannot be used, so an
        MLA model is never handed a non-MLA backend.
        """
        # TODO(lucas): refactor to  be more concise
        #  we should probably consider factoring out V1 here
        if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
            if use_v1:
                logger.info_once("Using Cutlass MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "cutlass_mla.CutlassMLABackend")
            logger.warning(
                "Cutlass MLA backend is only supported on V1 engine")

        # FlashMLA only handles block size 64; Triton MLA covers the rest.
        if selected_backend != _Backend.TRITON_MLA and block_size == 64:
            from vllm.attention.backends.flashmla import (
                is_flashmla_supported)
            # Query once; element 0 is the support flag, element 1 the reason.
            flashmla_status = is_flashmla_supported()
            if flashmla_status[0]:
                if use_v1:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashmla.FlashMLABackend")
                logger.info("Using FlashMLA backend.")
                return ("vllm.attention.backends."
                        "flashmla.FlashMLABackend")
            logger.warning("FlashMLA backend is not supported due to %s",
                           flashmla_status[1])

        if use_v1:
            logger.info_once("Using Triton MLA backend on V1 engine.")
            return ("vllm.v1.attention.backends.mla."
                    "triton_mla.TritonMLABackend")
        logger.info("Using Triton MLA backend.")
        return "vllm.attention.backends.triton_mla.TritonMLABackend"

    @classmethod
    def _get_v1_attn_backend_cls(cls, selected_backend, head_size,
                                 dtype) -> str:
        """Pick a non-MLA backend for the V1 engine.

        An explicitly requested FlashInfer, FlexAttention, Triton or
        FlashAttention backend is honoured; anything else uses the defaults.
        """
        FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
        FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
        TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501

        if selected_backend == _Backend.FLASHINFER:
            logger.info_once("Using FlashInfer backend on V1 engine.")
            return FLASHINFER_V1
        if selected_backend == _Backend.FLEX_ATTENTION:
            logger.info_once("Using FlexAttention backend on V1 engine.")
            return FLEX_ATTENTION_V1
        if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
            logger.info_once("Using Triton backend on V1 engine.")
            return TRITON_ATTN_VLLM_V1
        if selected_backend == _Backend.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend on V1 engine.")
            return FLASH_ATTN_V1

        from vllm.attention.selector import supports_head_size

        # Default backends for V1 engine
        # FP32 is only supported by FlexAttention
        if dtype not in (torch.float16, torch.bfloat16):
            logger.info_once(
                "Using FlexAttention backend for %s on V1 engine.",
                dtype,
            )
            return FLEX_ATTENTION_V1

        # Prefer FlashInfer for Blackwell GPUs if installed
        if cls.is_device_capability(100) and \
            supports_head_size(FLASHINFER_V1, head_size):
            try:
                import flashinfer  # noqa: F401
            except ImportError:
                logger.info_once(
                    "FlashInfer failed to import for V1 engine on "
                    "Blackwell (SM 10.0) GPUs; it is recommended to "
                    "install FlashInfer for better performance.")
            else:
                logger.info_once(
                    "Using FlashInfer backend on V1 engine by default for "
                    "Blackwell (SM 10.0) GPUs.")
                return FLASHINFER_V1

        # FlashAttention is the default for SM 8.0+ GPUs
        if cls.has_device_capability(80) and \
            supports_head_size(FLASH_ATTN_V1, head_size):
            logger.info_once("Using Flash Attention backend on V1 engine.")
            return FLASH_ATTN_V1

        logger.info_once("Using FlexAttention backend on V1 engine.")
        return FLEX_ATTENTION_V1

    @classmethod
    def _get_v0_attn_backend_cls(cls, selected_backend, head_size, dtype,
                                 kv_cache_dtype, block_size, use_v1,
                                 use_mla) -> str:
        """Pick a non-MLA backend for the V0 engine.

        FlashAttention is preferred; XFormers is the fallback whenever
        FlashAttention cannot serve the request. ``use_v1`` and ``use_mla``
        are only reported in the error message.

        Raises:
            ValueError: ``selected_backend`` is not supported on V0.
        """
        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
        if selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        if selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
            logger.info("Using DualChunkFlashAttention backend.")
            return ("vllm.attention.backends.dual_chunk_flash_attn."
                    "DualChunkFlashAttentionBackend")
        if selected_backend and selected_backend != _Backend.FLASH_ATTN:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        # FLASH_ATTN was requested, or nothing was: try FlashAttention.
        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend, flash_attn_supports_fp8)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS

                fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))
                if fp8_kv_cache and not flash_attn_supports_fp8():
                    logger.info(
                        "Cannot use FlashAttention backend for FP8 KV cache.")
                    logger.warning(
                        "Please use FlashInfer backend with FP8 KV Cache for "
                        "better performance by setting environment variable "
                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the GPU Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the CUDA device communicator."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 requires compute capability 8.9 or newer."""
        return cls.has_device_capability(89)

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        """The V1 engine is supported for every model on CUDA."""
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """CUDA supports vLLM's custom all-reduce kernels."""
        return True

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        """Return the import path of the CUDA piecewise compile backend."""
        return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an NCCL-backed ``ProcessGroup`` without the global c10d state.

        ``backend`` is accepted for interface compatibility but not used: an
        NCCL backend is always registered for the ``cuda`` device.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the number of visible CUDA devices without caching."""
        return cuda_device_count_stateless()

424

425
426
427
428
429
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`:
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    Public queries are wrapped in ``with_nvml_context`` so each call opens
    and closes its own NVML session. Logical device ids are first mapped to
    physical ids with ``device_id_to_physical_device_id``.
    """

    @classmethod
    @cache
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id``.

        The result is cached per ``(cls, device_id)``. Returns None when a
        ``RuntimeError`` is raised while resolving or querying the device.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Like the base check, but reports False on ``RuntimeError``."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of logical device ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of logical device ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory reported by NVML for ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of the given physical GPUs must report
        ``NVML_P2P_STATUS_OK`` for the NVLink P2P capability. An NVML error
        during the query is logged and treated as "not connected".
        """
        from itertools import combinations

        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        for handle, peer_handle in combinations(handles, 2):
            try:
                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                    handle,
                    peer_handle,
                    pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                )
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
            except pynvml.NVMLError:
                logger.exception(
                    "NVLink detection failed. This is normal if"
                    " your machine has no NVLink equipped.")
                return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of *physical* device ``device_id``.

        Not wrapped in ``with_nvml_context``: the caller must already hold an
        NVML session.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn about heterogeneous GPUs when device order is not PCI-based."""
        device_count = pynvml.nvmlDeviceGetCount()
        if device_count > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_count)
            ]
            if (len(set(device_names)) > 1
                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform for systems without NVML; queries go through torch.cuda.

    Used when ``pynvml.nvmlInit()`` fails, for example on Jetson.
    """

    @classmethod
    @cache
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return (and cache) the compute capability reported by torch."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by torch."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total device memory reported by torch."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """NVLink topology cannot be queried without NVML; always False."""
        # Not inside an exception handler, so `logger.exception` would attach
        # a meaningless "NoneType: None" traceback; a warning is appropriate.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Select the CUDA platform implementation at import time: the NVML-backed
# variant when NVML can be initialized, otherwise the torch.cuda fallback.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # Only close the probe session when the init above succeeded.
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()