kernel_warmup.py 742 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warm up kernels used during model execution.
This is especially useful for JIT-compiled kernels, since we don't want JIT
compilation to happen during model execution.
"""
import torch

import vllm.envs as envs
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.utils.deep_gemm import is_deep_gemm_supported


def kernel_warmup(model: torch.nn.Module, max_tokens: int):
    """Run the DeepGEMM warmup for ``model`` when it is enabled and supported.

    The warmup is triggered only if all of the following hold, checked in
    this order (later checks are skipped once one fails):

    1. ``envs.VLLM_USE_DEEP_GEMM`` is truthy.
    2. ``is_deep_gemm_supported()`` returns a truthy value.
    3. ``envs.VLLM_SKIP_DEEP_GEMM_WARMUP`` is falsy.

    Args:
        model: The module handed to ``deep_gemm_warmup``.
        max_tokens: Forwarded unchanged to ``deep_gemm_warmup``
            (presumably the largest token count to warm up for -- confirm
            against ``deep_gemm_warmup``).
    """
    # DeepGEMM usage is switched off entirely: nothing to warm up.
    if not envs.VLLM_USE_DEEP_GEMM:
        return

    # Only probe for support once we know DeepGEMM was requested.
    if not is_deep_gemm_supported():
        return

    # The warmup step has been explicitly opted out of.
    if envs.VLLM_SKIP_DEEP_GEMM_WARMUP:
        return

    deep_gemm_warmup(model, max_tokens)