# SPDX-License-Identifier: Apache-2.0
"""Code inside this file can safely assume the CUDA platform, e.g. importing
pynvml. However, it should not initialize a CUDA context.
"""

import os
from datetime import timedelta
from functools import wraps
from itertools import combinations
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec

# import custom ops, trigger op registration
import vllm._C  # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import import_pynvml

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig

logger = init_logger(__name__)

# Signature-preserving type variables for the `with_nvml_context` decorator.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# NVML bindings obtained through vLLM's `import_pynvml()` helper; every NVML
# call in this module goes through this name.
pynvml = import_pynvml()

# PyTorch 2.5 turns on the cuDNN SDPA kernel by default, and that crashes on
# some models (https://github.com/huggingface/diffusers/issues/9704), so it is
# disabled globally at import time.
torch.backends.cuda.enable_cudnn_sdp(False)

def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so every call runs with NVML initialized.

    ``pynvml.nvmlInit()`` is called before ``fn`` and ``pynvml.nvmlShutdown()``
    is always called afterwards, whether ``fn`` returns or raises. The
    wrapper keeps ``fn``'s signature and metadata via ``functools.wraps``.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            # Balance the nvmlInit() above even when fn raises.
            pynvml.nvmlShutdown()
        return outcome

    return _nvml_scoped


class CudaPlatformBase(Platform):
    """Platform logic shared by every CUDA platform implementation.

    The device-query primitives (`get_device_capability`, `get_device_name`,
    `get_device_total_memory`, `is_fully_connected`) raise
    `NotImplementedError` here and are provided by subclasses. Config
    defaults, attention-backend selection and process-group setup are built
    on top of them.
    """
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Torch dtypes usable on this GPU, derived from its capability."""
        if self.has_device_capability(80):
            # Ampere, Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs; BF16 is not supported.
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs only support FP32,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is available unless eager mode is forced.

        Logs a warning and returns False when ``enforce_eager`` is truthy.
        """
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        """Return whether all ``device_ids`` are interconnected (subclass hook).

        NOTE(review): subclasses name this parameter ``physical_device_ids``;
        keyword callers must match whichever class they call.
        """
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Log platform-specific setup warnings; the base class has none."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply CUDA-specific defaults to ``vllm_config`` in place.

        * Resolves ``parallel_config.worker_cls`` when it is ``"auto"``, based
          on multi-step scheduling, speculative decoding and ``VLLM_USE_V1``.
        * Defaults the KV-cache block size to 16 when it is unset.
        * For MLA models that will use FlashMLA, forces a block size of 64.

        Raises:
            NotImplementedError: multi-step scheduling requested on V1.
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_worker.MultiStepWorker"
            elif envs.VLLM_USE_V1:
                # V1 uses the same GPU worker with or without speculative
                # decoding.
                parallel_config.worker_cls = \
                        "vllm.v1.worker.gpu_worker.Worker"
            elif vllm_config.speculative_config:
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = \
                    "vllm.worker.worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing. cache_config is
        # checked as well because the block-size override below writes to it.
        if (model_config is not None and model_config.use_mla
                and cache_config is not None):
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # we default to the FlashMLA backend, so the block size has to
            # be forced here.
            use_flashmla = envs.VLLM_ATTENTION_BACKEND in (None, "FLASHMLA")
            from vllm.attention.ops.flashmla import is_flashmla_supported
            if (use_flashmla and is_flashmla_supported()[0]
                    and cache_config.block_size != 64):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLA backend.")

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return memory currently allocated by torch on ``device``.

        The peak statistics are reset first, so the subsequent
        `max_memory_allocated` reading reflects the current allocation
        rather than a historical peak.
        """
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def _get_mla_attn_backend_cls(cls, selected_backend, block_size,
                                  use_v1) -> str:
        """Select the attention backend class path for an MLA model.

        FlashMLA is chosen when all three hold:

        * Triton MLA was not explicitly requested.
        * The block size is 64, the only size FlashMLA supports.
        * `is_flashmla_supported()` reports support.

        Every other case, including FlashMLA being unavailable on this
        system, uses Triton MLA.
        """
        # TODO(lucas): refactor to be more concise
        #  we should probably consider factoring out V1 here
        want_flashmla = (selected_backend != _Backend.TRITON_MLA
                         and block_size == 64)
        if want_flashmla:
            from vllm.attention.ops.flashmla import is_flashmla_supported
            flashmla_status = is_flashmla_supported()
            if flashmla_status[0]:
                if use_v1:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashmla.FlashMLABackend")
                logger.info("Using FlashMLA backend.")
                return ("vllm.attention.backends."
                        "flashmla.FlashMLABackend")
            logger.warning("FlashMLA backend is not supported due to %s",
                           flashmla_status[1])
            # Do not fall through to the generic backends: they cannot serve
            # an MLA model. Triton MLA is the general MLA fallback.

        if use_v1:
            logger.info_once("Using Triton MLA backend on V1 engine.")
            return ("vllm.v1.attention.backends.mla."
                    "triton_mla.TritonMLABackend")
        logger.info("Using Triton MLA backend.")
        return "vllm.attention.backends.triton_mla.TritonMLABackend"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the import path of the attention backend class to use.

        Args:
            selected_backend: explicitly requested `_Backend`, or a falsy
                value (e.g. None) to let this method decide.
            head_size: attention head size; checked against
                `FlashAttentionBackend.get_supported_head_sizes()`.
            dtype: model dtype; FlashAttention needs fp16 or bf16.
            kv_cache_dtype: KV-cache dtype string, or None. Values starting
                with ``"fp8"`` require FP8 support in FlashAttention.
            block_size: KV-cache block size.
            use_v1: whether the V1 engine is in use.
            use_mla: whether the model uses MLA attention.

        Raises:
            ValueError: ``selected_backend`` is set to a backend that the
                non-MLA V0 path does not handle.
        """
        if use_mla:
            return cls._get_mla_attn_backend_cls(selected_backend, block_size,
                                                 use_v1)

        if use_v1:
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
            if cls.has_device_capability(80):
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "flash_attn.FlashAttentionBackend")
            # NOTE(review): pre-Ampere GPUs reach the V0 selection below even
            # with use_v1 set -- confirm this is intended.

        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
        elif selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
            logger.info("Using DualChunkFlashAttention backend.")
            return ("vllm.attention.backends.dual_chunk_flash_attn."
                    "DualChunkFlashAttentionBackend")
        elif selected_backend == _Backend.FLASH_ATTN:
            # Explicit FlashAttention still goes through the checks below.
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttention is valid for the model so far. Check that the
        # package is installed and that it supports this head size and
        # KV-cache dtype.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend, flash_attn_supports_fp8)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS
                fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))
                if fp8_kv_cache and not flash_attn_supports_fp8():
                    logger.info(
                        "Cannot use FlashAttention backend for FP8 KV cache.")
                    logger.warning(
                        "Please use FlashInfer backend with FP8 KV Cache for "
                        "better performance by setting environment variable "
                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Import path of the GPU Punica wrapper used for LoRA."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Import path of the CUDA device communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 requires device capability 89 (8.9) or newer."""
        return cls.has_device_capability(89)

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        """The V1 engine is supported for every model on CUDA."""
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Custom all-reduce is enabled on CUDA."""
        return True

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        """Import path of the CUDA piecewise compilation backend."""
        return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Construct an NCCL-backed `ProcessGroup` directly.

        Builds the group from ``prefix_store``, attaches a `ProcessGroupNCCL`
        configured with ``timeout``, and registers it as the default backend
        for the ``cuda`` device. ``backend`` is accepted for interface
        compatibility but is not used; NCCL is always chosen.

        Raises:
            AssertionError: NCCL is not available in this torch build.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, so all the
# related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    Public queries are wrapped in `with_nvml_context`, so NVML is initialized
    before and shut down after each call. Logical device ids are translated
    to physical NVML indices with `device_id_to_physical_device_id`.
    """

    @classmethod
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of logical device ``device_id``.

        Returns None if a ``RuntimeError`` is raised while resolving the
        physical id or querying NVML.
        """
        try:
            physical_device_id = cls.device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Delegate to the base check, returning False on ``RuntimeError``.

        Wrapping in `with_nvml_context` keeps NVML initialized for the whole
        check rather than once per nested query.
        """
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name of logical device ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of logical device ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``nvmlDeviceGetMemoryInfo(...).total`` for ``device_id``."""
        physical_device_id = cls.device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Query whether every pair of GPUs is directly linked by NVLink.

        Args:
            physical_device_ids: physical (NVML) device indices.

        Returns:
            True only if every unordered pair reports
            ``NVML_P2P_STATUS_OK`` for the NVLink P2P capability. Returns
            False on the first failing pair or on any ``NVMLError``.
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        for handle, peer_handle in combinations(handles, 2):
            try:
                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                    handle,
                    peer_handle,
                    pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                )
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
            except pynvml.NVMLError:
                logger.exception(
                    "NVLink detection failed. This is normal if"
                    " your machine has no NVLink equipped.")
                return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of physical device index ``device_id``.

        Not wrapped in `with_nvml_context`: call it only from a method that
        already is, such as `get_device_name` or `log_warnings`.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when the system has mixed GPU models but no PCI bus ordering.

        The warning fires only when more than one device is present, their
        NVML names differ, and ``CUDA_DEVICE_ORDER`` is not ``PCI_BUS_ID``.
        """
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return
        device_names = [
            cls._get_physical_device_name(i) for i in range(device_count)
        ]
        if (len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform fallback used when NVML cannot be initialized.

    Device queries go through `torch.cuda` instead of NVML.
    NOTE(review): unlike the NVML path, these `torch.cuda` calls may
    initialize CUDA in this process -- confirm that is acceptable here.
    """

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of ``device_id`` via torch.cuda."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the torch.cuda device name of ``device_id``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch.cuda device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """NVLink topology cannot be queried without NVML; always False."""
        # No exception is active here, so `logger.exception` would attach a
        # meaningless "NoneType: None" traceback at ERROR level. This is an
        # expected condition, so log a plain warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


def _nvml_is_available() -> bool:
    """Probe whether NVML can be initialized in this process.

    A successful ``nvmlInit()`` is immediately balanced by ``nvmlShutdown()``.
    Any ``Exception`` from ``nvmlInit()`` means NVML is unavailable.
    """
    try:
        pynvml.nvmlInit()
    except Exception:
        # On Jetson, NVML is not supported.
        return False
    pynvml.nvmlShutdown()
    return True


# Pick the NVML-backed platform when NVML works, otherwise the torch.cuda
# fallback, then emit any platform setup warnings once at import time.
nvml_available = _nvml_is_available()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

CudaPlatform.log_warnings()