"requirements/common.txt" did not exist on "a6221a144af772fd1a68fe7e627935dc53e81738"
cuda.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

6
import os
7
from functools import wraps
8
9
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
                    Union)
10

11
import torch
12
from typing_extensions import ParamSpec
13

14
15
# import custom ops, trigger op registration
import vllm._C  # noqa
16
import vllm.envs as envs
17
from vllm.logger import init_logger
18
from vllm.utils import import_pynvml
19

20
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
21

22
if TYPE_CHECKING:
23
    from vllm.config import ModelConfig, VllmConfig
24

25
26
logger = init_logger(__name__)

# Type variables that let `with_nvml_context` preserve the exact signature
# of the function it wraps.
_P, _R = ParamSpec("_P"), TypeVar("_R")

# NVML bindings obtained through vllm.utils.import_pynvml; used by the
# NVML-backed platform and by the availability probe further down.
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)


def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical CUDA device index into a physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, logical index ``device_id`` picks
    the ``device_id``-th entry of that comma-separated list. Otherwise the
    logical and physical indices are the same.

    Args:
        device_id: Logical device index (position within
            ``CUDA_VISIBLE_DEVICES`` when that variable is set).

    Returns:
        The physical device index.

    Raises:
        RuntimeError: If ``CUDA_VISIBLE_DEVICES`` is set to an empty (or
            whitespace-only) string, which disables GPU support.
        IndexError: If ``device_id`` is out of range for the entries listed
            in ``CUDA_VISIBLE_DEVICES``.
        ValueError: If the selected entry is not an integer.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    # Whitespace-only must be treated like "" -- otherwise it would surface
    # later as an opaque `int()` ValueError instead of this explanation.
    if not visible_devices.strip():
        msg = (
            "CUDA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information.")
        raise RuntimeError(msg)

    device_ids = visible_devices.split(",")
    try:
        physical_device_id = device_ids[device_id]
    except IndexError:
        raise IndexError(
            f"Device index {device_id} is out of range for "
            f"CUDA_VISIBLE_DEVICES={visible_devices!r}") from None
    # int() tolerates surrounding whitespace, e.g. "0, 1".
    return int(physical_device_id)


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator that brackets every call of ``fn`` with NVML init/shutdown.

    ``pynvml.nvmlInit()`` runs before the call and ``pynvml.nvmlShutdown()``
    runs afterwards, even when ``fn`` raises. The decorated function keeps
    the name, docstring and signature of ``fn``.
    """

    @wraps(fn)
    def _call_inside_nvml(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            # Always pair the init above, whether fn returned or raised.
            pynvml.nvmlShutdown()
        return outcome

    return _call_inside_nvml


class CudaPlatformBase(Platform):
    """CUDA platform behaviour shared by the NVML and non-NVML variants.

    Device queries (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory``, ``is_fully_connected``) raise
    ``NotImplementedError`` here and are supplied by subclasses. Config
    defaults, attention-backend selection and feature flags are implemented
    once for all CUDA devices.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @property
    def supported_dtypes(self) -> List[torch.dtype]:
        """Dtypes usable on the current device.

        Capability >= 8.0 adds bfloat16, >= 6.0 allows float16, and anything
        older is limited to float32.
        """
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
        if self.has_device_capability(60):
            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
            return [torch.float16, torch.float32]
        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing needs CUDA graphs, so not with eager."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def is_fully_connected(cls, device_ids: List[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Hook for platform-specific import-time warnings; none by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in CUDA-specific defaults on ``vllm_config`` in place.

        - Resolves ``parallel_config.worker_cls == "auto"`` to a concrete
          worker class (multi-step, speculative decoding, V1 or V0).
        - Defaults the KV-cache block size to 16.
        - Forces block size 64 when an MLA model will use FlashMLA.
        - Disables CUDA graphs (enforce eager) under data parallelism.

        Raises:
            NotImplementedError: If multi-step scheduling is requested on V1.
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        compilation_config = vllm_config.compilation_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                            "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = \
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    parallel_config.sd_worker_cls = \
                        "vllm.worker.worker.Worker"
            elif envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                        "vllm.v1.worker.gpu_worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
            # we default to FlashMLA backend, so we need to force the blocksize
            # here
            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None
                            or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
            from vllm.attention.ops.flashmla import is_flashmla_supported
            if (use_flashmla and is_flashmla_supported()[0]
                    and cache_config is not None
                    and cache_config.block_size != 64):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLA backend.")

        if (parallel_config.data_parallel_size > 1
                and compilation_config.use_cudagraph):
            logger.info(
                "Data Parallel: Forcing enforce eager to be True since DP is "
                "currently not supported with CUDA Graphs.")
            # model_config may be None during testing (see above).
            if model_config is not None:
                model_config.enforce_eager = True
            compilation_config.use_cudagraph = False

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the bytes currently allocated by torch on ``device``.

        Resetting the peak statistics first makes ``max_memory_allocated``
        report the current allocation rather than a historical peak. Side
        effect: the device's peak-memory statistics are reset.
        """
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the import path of the attention backend class to use.

        Args:
            selected_backend: Backend explicitly requested (a ``_Backend``
                member), or a falsy value to auto-select.
            head_size: Attention head size; FlashAttention only supports the
                sizes reported by ``get_supported_head_sizes()``.
            dtype: Model dtype; FlashAttention needs float16 or bfloat16.
            kv_cache_dtype: KV-cache dtype string or None; values starting
                with ``"fp8"`` require FP8 support in the backend.
            block_size: KV-cache block size.
            use_v1: Whether the V1 engine is in use.
            use_mla: Whether the model uses MLA attention.

        Raises:
            ValueError: If ``selected_backend`` is not valid on the non-MLA
                V0 selection path.
        """
        if use_mla:
            # TODO(lucas): refactor to  be more concise
            #  we should probably consider factoring out V1 here
            use_triton_mla = (selected_backend == _Backend.TRITON_MLA
                              or block_size != 64)
            if not use_triton_mla:
                # Here block_size == 64, the only size FlashMLA supports.
                from vllm.attention.ops.flashmla import is_flashmla_supported
                flashmla_status = is_flashmla_supported()
                if flashmla_status[0]:
                    if use_v1:
                        logger.info_once(
                            "Using FlashMLA backend on V1 engine.")
                        return ("vllm.v1.attention.backends.mla."
                                "flashmla.FlashMLABackend")
                    logger.info("Using FlashMLA backend.")
                    return ("vllm.attention.backends."
                            "flashmla.FlashMLABackend")
                # An MLA model must get an MLA backend: fall back to Triton
                # MLA instead of dropping into the non-MLA selection below.
                logger.warning(
                    "FlashMLA backend is not supported due to %s",
                    flashmla_status[1])
            if use_v1:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"

        if use_v1:
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
            if cls.has_device_capability(80):
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "flash_attn.FlashAttentionBackend")
            # NOTE(review): pre-8.0 devices on V1 fall through to the V0
            # selection below -- confirm this is intended.

        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
        elif selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        elif selected_backend == _Backend.FLASH_ATTN:
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend, flash_attn_supports_fp8)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS
                fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))
                if (fp8_kv_cache and not flash_attn_supports_fp8()):
                    logger.info(
                        "Cannot use FlashAttention backend for FP8 KV cache.")
                    logger.warning(
                        "Please use FlashInfer backend with FP8 KV Cache for "
                        "better performance by setting environment variable "
                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 requires compute capability 8.9 or newer."""
        return cls.has_device_capability(89)

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through NVML.

    Logical device ids are first mapped to physical ids with
    ``device_id_to_physical_device_id``. Every public query runs inside
    ``with_nvml_context``.
    """

    @classmethod
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Compute capability of ``device_id``, or None on RuntimeError."""
        try:
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
                device_id_to_physical_device_id(device_id))
            cc_major, cc_minor = pynvml.nvmlDeviceGetCudaComputeCapability(
                nvml_handle)
            return DeviceCapability(major=cc_major, minor=cc_minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Base-class capability check; a RuntimeError counts as False."""
        try:
            answer = super().has_device_capability(capability, device_id)
        except RuntimeError:
            answer = False
        return answer

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        return cls._get_physical_device_name(
            device_id_to_physical_device_id(device_id))

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            device_id_to_physical_device_id(device_id))
        return pynvml.nvmlDeviceGetUUID(nvml_handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            device_id_to_physical_device_id(device_id))
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle)
        return int(mem_info.total)

    @classmethod
    @with_nvml_context
    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        # Resolve every handle up front, then check each unordered pair once.
        all_handles = [
            pynvml.nvmlDeviceGetHandleByIndex(gpu)
            for gpu in physical_device_ids
        ]
        for pos, src_handle in enumerate(all_handles):
            for dst_handle in all_handles[pos + 1:]:
                try:
                    link_status = pynvml.nvmlDeviceGetP2PStatus(
                        src_handle,
                        dst_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if link_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # Caller must already be inside an NVML context.
        return pynvml.nvmlDeviceGetName(
            pynvml.nvmlDeviceGetHandleByIndex(device_id))

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when several distinct GPU models are present without
        ``CUDA_DEVICE_ORDER=PCI_BUS_ID``."""
        num_devices: int = pynvml.nvmlDeviceGetCount()
        if num_devices <= 1:
            return
        names = [
            cls._get_physical_device_name(idx) for idx in range(num_devices)
        ]
        mixed_models = len(set(names)) > 1
        if mixed_models and os.environ.get(
                "CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
            logger.warning(
                "Detected different devices in the system: %s. Please"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                ", ".join(names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``.

    Selected when NVML cannot be initialized. ``device_id`` is passed
    directly to ``torch.cuda`` without ``device_id_to_physical_device_id``
    translation. Unlike the NVML path, these calls may initialize CUDA.
    """

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Compute capability as reported by ``torch.cuda``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
        """NVLink topology cannot be queried without NVML; returns False."""
        # This is not inside an ``except`` block, so ``logger.exception``
        # would only append a meaningless "NoneType: None" traceback.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # The probe succeeded; shut NVML down again until it is actually needed.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# Emit any platform-specific warnings once, at import time.
CudaPlatform.log_warnings()