cuda.py 17.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

6
import os
7
from functools import wraps
8
9
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
                    Union)
10

11
import torch
12
from typing_extensions import ParamSpec
13

14
15
# import custom ops, trigger op registration
import vllm._C  # noqa
16
import vllm.envs as envs
17
from vllm.fa_utils import get_flash_attn_version
18
from vllm.logger import init_logger
19
from vllm.utils import import_pynvml
20

21
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
22

23
24
25
26
27
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    # Type checkers see the real class; at runtime the name is only a
    # placeholder so the annotations below still resolve.
    VllmConfig = None

logger = init_logger(__name__)

# Type variables used to preserve the wrapped signature in decorators.
_P = ParamSpec("_P")
_R = TypeVar("_R")

pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

def device_id_to_physical_device_id(device_id: int) -> int:
    """Map a logical CUDA device index to a physical device index.

    If ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes into its
    comma-separated entries and the selected entry is parsed as an int.
    Otherwise ``device_id`` is returned unchanged.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is set to the empty string.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    # "".split(",") == [""], so an empty value means no GPU is visible.
    if not visible_devices:
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information.")

    entries = visible_devices.split(",")
    return int(entries[device_id])
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so it runs inside an NVML session.

    ``pynvml.nvmlInit()`` is called before ``fn`` runs, and
    ``pynvml.nvmlShutdown()`` is called afterwards. The shutdown happens in a
    ``finally`` clause, so it also runs when ``fn`` raises.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper
class CudaPlatformBase(Platform):
    """CUDA platform logic shared by the NVML and non-NVML variants.

    The device queries (capability, name, total memory, NVLink topology)
    raise ``NotImplementedError`` here and are supplied by subclasses. This
    class implements config defaults and attention-backend selection on top
    of those queries.
    """
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return False (with a warning) when eager mode is enforced."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def is_full_nvlink(cls, device_ids: List[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in CUDA-specific defaults on ``vllm_config`` in place.

        - Resolves ``parallel_config.worker_cls == "auto"`` to a concrete
          worker class.
        - Defaults ``cache_config.block_size`` to 16 when unset.
        - Forces a block size of 64 for FlashMLA on MLA models.
        - Disables CUDA graphs when data parallelism is enabled.
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        compilation_config = vllm_config.compilation_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                        "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = \
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    parallel_config.sd_worker_cls = \
                        "vllm.worker.worker.Worker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                        "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
            # we default to FlashMLA backend, so we need to force the blocksize
            # here
            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None
                            or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
            from vllm.attention.ops.flashmla import is_flashmla_supported
            # Guard cache_config: it is checked for None above, so it may be
            # missing here as well.
            if (use_flashmla and cache_config is not None
                    and is_flashmla_supported()[0]
                    and cache_config.block_size != 64):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLA backend.")

        if (parallel_config.data_parallel_size > 1
                and compilation_config.use_cudagraph):
            logger.info(
                "Data Parallel: Forcing enforce eager to be True since DP is "
                "currently not supported with CUDA Graphs.")
            # model_config may be None (see note above).
            if model_config is not None:
                model_config.enforce_eager = True
            compilation_config.use_cudagraph = False

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Reset the peak-memory counter, then return the allocated peak.

        After the reset, ``max_memory_allocated`` presumably reflects the
        memory currently allocated on ``device``.
        """
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def _get_triton_mla_backend_cls(cls, use_v1: bool) -> str:
        """Return the Triton MLA backend class path for the V0 or V1 engine."""
        if use_v1:
            logger.info_once("Using Triton MLA backend on V1 engine.")
            return ("vllm.v1.attention.backends.mla."
                    "triton_mla.TritonMLABackend")
        logger.info("Using Triton MLA backend.")
        return "vllm.attention.backends.triton_mla.TritonMLABackend"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the fully qualified attention backend class path.

        MLA models get Triton MLA or FlashMLA. Otherwise V1 prefers Triton
        (when selected) or FlashAttention on compute capability >= 8.0. V0
        picks FlashInfer, XFormers or FlashAttention based on
        ``selected_backend``, device capability, dtype, block size, head
        size, FP8 KV cache and package availability.

        Raises:
            ValueError: if ``selected_backend`` is not usable on this path.
        """
        if use_mla:
            # TODO(lucas): refactor to  be more concise
            #  we should probably consider factoring out V1 here
            # FlashMLA requires block size 64; anything else (or an explicit
            # TRITON_MLA request) uses Triton MLA.
            if selected_backend == _Backend.TRITON_MLA or block_size != 64:
                return cls._get_triton_mla_backend_cls(use_v1)

            from vllm.attention.ops.flashmla import is_flashmla_supported
            flashmla_status = is_flashmla_supported()
            if not flashmla_status[0]:
                logger.warning(
                    "FlashMLA backend is not supported due to %s",
                    flashmla_status[1])
                # The non-MLA selection below ignores use_mla, so stay on an
                # MLA backend instead of falling through to it.
                return cls._get_triton_mla_backend_cls(use_v1)

            if use_v1:
                logger.info_once("Using FlashMLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashmla.FlashMLABackend")
            logger.info("Using FlashMLA backend.")
            return ("vllm.attention.backends."
                    "flashmla.FlashMLABackend")

        if use_v1:
            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
            if cls.has_device_capability(80):
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "flash_attn.FlashAttentionBackend")

        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
        elif selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        elif selected_backend == _Backend.FLASH_ATTN:
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS
                fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))
                if (fp8_kv_cache and get_flash_attn_version() != 3):
                    logger.info(
                        "Cannot use FlashAttention-2 backend for FP8 KV cache."
                    )
                    logger.warning(
                        "Please use FlashInfer backend with FP8 KV Cache for "
                        "better performance by setting environment variable "
                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 requires compute capability 8.9 or newer."""
        return cls.has_device_capability(89)
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML.

    Every public query runs inside ``with_nvml_context``. Queries taking a
    logical ``device_id`` first map it through
    ``device_id_to_physical_device_id``.
    """

    @classmethod
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability, or None on a RuntimeError."""
        try:
            physical_device_id = device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Delegate to the base check; treat a RuntimeError as False."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        # Check each unordered pair once.
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1:]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of a physical device.

        Not wrapped in an NVML session itself; the callers in this class
        (``get_device_name``, ``log_warnings``) already are.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when heterogeneous GPUs are present without PCI_BUS_ID order."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices <= 1:
            return
        device_names = [
            cls._get_physical_device_name(i) for i in range(num_devices)
        ]
        if (len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through ``torch.cuda``."""

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Always False: NVLink topology cannot be queried without NVML."""
        # logger.exception is only meaningful inside an exception handler;
        # there is no active exception here, so log a plain warning.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Pick the NVML-backed platform when NVML can be initialized, and fall back
# to the torch.cuda-backed platform otherwise.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # The probe only needed to know whether init works; release it again.
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# Emit platform warnings at import time, except when Sphinx autodoc has
# replaced pynvml with a mock module (docs builds).
try:
    from sphinx.ext.autodoc.mock import _MockModule
except ModuleNotFoundError:
    CudaPlatform.log_warnings()
else:
    # Only the import is guarded, so an error raised by log_warnings()
    # cannot trigger a second call through the except branch.
    if not isinstance(pynvml, _MockModule):
        CudaPlatform.log_warnings()