kernel_warmup.py 4.13 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup kernels used during model execution.
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
happen during model execution.
"""
8

9
10
from typing import TYPE_CHECKING

11
12
13
import torch

import vllm.envs as envs
14
from vllm.config import CUDAGraphMode, VllmConfig
15
from vllm.logger import init_logger
16
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
17
from vllm.platforms import current_platform
18
from vllm.utils.deep_gemm import is_deep_gemm_supported
19
20
21
22
23
from vllm.utils.flashinfer import has_flashinfer

if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
    from vllm.v1.worker.gpu_worker import Worker
24

25
26
logger = init_logger(__name__)

27

28
29
30
31
32
def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
    """
    Report whether FlashInfer autotuning is expected to complete cleanly.

    This is the place to record known bad interactions between vLLM and
    FlashInfer autotune. Currently one combination is treated as
    unsupported: all of the following holding at once.

    * Data parallelism or tensor parallelism is enabled (size > 1).
    * A FlashInfer MXFP4 MoE backend is in use. This is either selected
      through one of the ``VLLM_USE_FLASHINFER_MOE_MXFP4_*`` environment
      flags, or implied by a CUDA device for which
      ``is_device_capability(100)`` holds.
    * CUDA graphs are disabled (``CUDAGraphMode.NONE``, i.e. eager mode).

    Returns:
        True if and only if autotuning should be run; False when the known
        problematic combination above is detected.
    """
    parallel_cfg = vllm_config.parallel_config
    uses_tp_or_dp = (
        parallel_cfg.data_parallel_size > 1
        or parallel_cfg.tensor_parallel_size > 1
    )

    # An explicit environment flag selects a FlashInfer MXFP4 MoE backend.
    fi_mxfp4_env_selected = (
        envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
    )
    # Without a flag, FlashInfer is assumed to be the default MXFP4 backend
    # on this device class.
    # NOTE(review): an earlier comment here said "on >=sm100", but the call
    # is `is_device_capability(100)`. It appears to match capability 10.0
    # exactly, whereas `has_device_capability` is the ">=" check. Confirm
    # which one is intended.
    uses_fi_mxfp4_backend = fi_mxfp4_env_selected or (
        current_platform.is_cuda() and current_platform.is_device_capability(100)
    )

    eager_mode = (
        vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    )

    known_bad_combo = uses_tp_or_dp and uses_fi_mxfp4_backend and eager_mode
    return not known_bad_combo
47


48
49
def kernel_warmup(worker: "Worker"):
    """
    Run one-time kernel warmups for ``worker`` ahead of model execution.

    Three independent steps run, each gated on its own condition.

    1. DeepGEMM warmup runs when ``VLLM_USE_DEEP_GEMM`` is set, DeepGEMM is
       supported, and ``VLLM_DEEP_GEMM_WARMUP`` is not ``"skip"``.
    2. FlashInfer autotuning runs when FlashInfer is available,
       ``current_platform.has_device_capability(90)`` holds, and
       ``flashinfer_autotune_supported`` reports no known issue.
    3. FlashInfer attention warmup runs when the model is not a pooling
       model, it has at least one attention group, and every attention
       group uses the ``"FLASHINFER"`` backend. It uses a small mixed
       prefill/decode dummy batch.

    Args:
        worker: The worker whose model and model runner are warmed up.
    """
    # Deep GEMM warmup
    do_deep_gemm_warmup = (
        envs.VLLM_USE_DEEP_GEMM
        and is_deep_gemm_supported()
        and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
    )
    if do_deep_gemm_warmup:
        model = worker.get_model()
        max_tokens = worker.scheduler_config.max_num_batched_tokens
        deep_gemm_warmup(model, max_tokens)

    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
    if (
        has_flashinfer()
        and current_platform.has_device_capability(90)
        and flashinfer_autotune_supported(worker.vllm_config)
    ):
        flashinfer_autotune(worker.model_runner)

    # FlashInfer attention warmup
    # Only warmup if the model has FlashInfer attention groups
    # and is not a pooling model
    def _is_flashinfer_backend(backend):
        try:
            return backend.get_name() == "FLASHINFER"
        except NotImplementedError:
            # A backend that does not implement get_name() is treated as
            # non-FlashInfer rather than aborting the warmup.
            return False

    runner = worker.model_runner
    if runner.is_pooling_model:
        return

    attn_backends = [
        group.backend for groups in runner.attn_groups for group in groups
    ]
    # `all()` over an empty sequence is True. Without this emptiness check,
    # a model with no attention groups would still run the FlashInfer
    # attention dummy batch below.
    if attn_backends and all(
        _is_flashinfer_backend(backend) for backend in attn_backends
    ):
        logger.info("Warming up FlashInfer attention.")
        # Warmup with mixed batch containing both prefill and decode tokens
        # This is to warm up both prefill and decode attention kernels
        runner._dummy_run(
            num_tokens=16,
            skip_eplb=True,
            is_profile=True,
            force_attention=True,
            create_mixed_batch=True,
        )
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110

def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """
    Run FlashInfer's autotuner over a single maximal dummy batch.

    FlashInfer provides several implementations of the same operation.
    Inside the ``autotune()`` context it benchmarks them and caches the
    results. Later FlashInfer calls then transparently use the best
    implementation, instead of relying on heuristics that may be
    significantly slower.

    Args:
        runner: Model runner whose ``_dummy_run`` drives the autotuning
            pass.
    """
    from vllm.utils.flashinfer import autotune

    # Autotuning with m tokens covers every token count up to m, so one
    # pass at the scheduler's max_num_batched_tokens is sufficient.
    token_budget = runner.scheduler_config.max_num_batched_tokens
    with torch.inference_mode(), autotune():
        # EPLB is skipped so that this dummy pass records no dummy metrics.
        runner._dummy_run(token_budget, skip_eplb=True, is_profile=True)