kernel_warmup.py 2.21 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warm up kernels used during model execution.
This is useful specifically for JIT'ed kernels, since we don't want JIT
compilation to happen during model execution.
"""
8
9
from typing import TYPE_CHECKING

10
11
12
13
import torch

import vllm.envs as envs
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
14
from vllm.platforms import current_platform
15
from vllm.utils.deep_gemm import is_deep_gemm_supported
16
17
18
19
20
from vllm.utils.flashinfer import has_flashinfer

if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
    from vllm.v1.worker.gpu_worker import Worker
21
22


23
24
def kernel_warmup(worker: "Worker"):
    """
    Run the kernel warmups that apply to ``worker``.

    Two independent steps may run, each behind its own gate:

    * DeepGEMM warmup, called with the worker's model and the scheduler's
      ``max_num_batched_tokens``.
    * FlashInfer autotuning, called with the worker's model runner.
    """
    # Deep GEMM warmup: must be enabled, supported, and not explicitly
    # skipped via the environment.
    if (envs.VLLM_USE_DEEP_GEMM and is_deep_gemm_supported()
            and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP):
        deep_gemm_warmup(
            worker.get_model(),
            worker.scheduler_config.max_num_batched_tokens,
        )

    # FlashInfer autotune for Blackwell (SM 10.0) GPUs
    if not has_flashinfer():
        return
    if not current_platform.is_device_capability(100):
        return
    flashinfer_autotune(worker.model_runner)


def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """
    Autotune FlashInfer operations.

    FlashInfer has many implementations for the same operation.
    Autotuning benchmarks each implementation and stores the results;
    the results are cached transparently, and future calls to FlashInfer
    will use the best implementation.
    Without autotuning, FlashInfer falls back to heuristics, which may
    be significantly slower.
    """
    from vllm.utils.flashinfer import autotune

    with torch.inference_mode():
        with autotune():
            # When autotuning with number of tokens m, flashinfer will
            # autotune operations for all number of tokens up to m, so a
            # single run at the max number of tokens is enough.
            token_budget = runner.scheduler_config.max_num_batched_tokens
            # We skip EPLB here since we don't want to record dummy metrics
            runner._dummy_run(token_budget, skip_eplb=True, is_profile=True)