Unverified commit b1d9f537, authored by Nick Hill and committed by GitHub
Browse files

[Model Runner V2] Warmup kernels (#35172)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent fd6de37f
...@@ -840,6 +840,24 @@ class SamplingParams( ...@@ -840,6 +840,24 @@ class SamplingParams(
f"extra_args={self.extra_args})" f"extra_args={self.extra_args})"
) )
@staticmethod
def for_sampler_warmup() -> "SamplingParams":
    """Build the sampling parameters used when warming up the sampler.

    The returned object sets temperature, top-p / top-k / min-p filtering,
    frequency / presence / repetition penalties, a minimum token count,
    a logit bias, bad-word token sequences, and both sample and prompt
    logprobs, with the intent of exercising all sampler logic in a single
    warmup pass.

    Returns:
        A new ``SamplingParams`` instance with the warmup settings.
    """
    # Small fixed token ids (0, 1, 2) are used for the per-token settings.
    warmup_logit_bias = {0: -1.0, 1: 0.5}
    warmup_bad_words = [[0], [1, 2]]

    warmup_params = SamplingParams(
        # Temperature and probability-based filtering.
        temperature=0.9,
        top_p=0.9,
        top_k=50,
        min_p=0.1,
        # Penalties.
        frequency_penalty=0.5,
        presence_penalty=0.5,
        repetition_penalty=1.2,
        # Token-level constraints.
        min_tokens=2,
        logit_bias=warmup_logit_bias,
        _bad_words_token_ids=warmup_bad_words,
        # Logprob outputs.
        logprobs=5,
        prompt_logprobs=1,
    )
    return warmup_params
class BeamSearchParams( class BeamSearchParams(
msgspec.Struct, msgspec.Struct,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm import PoolingParams, SamplingParams
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.request import Request
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
@torch.inference_mode()
def warmup_kernels(model_runner: GPUModelRunner) -> None:
    """Run two execute_model + sample_tokens iterations to JIT compile
    triton kernels.

    The first iteration simulates a prefill with requests of 2 prompt
    tokens each. The second iteration simulates a decode step with all
    requests generating 1 token each. Pooling models only run the prefill
    iteration and skip sampling and the decode step.

    The KV connector is disabled while the warmup runs and is re-enabled
    on exit, including when one of the warmup steps raises.

    Args:
        model_runner: The GPU model runner whose kernels are warmed up.
    """
    prompt_token_ids = [0, 1]
    prompt_len = len(prompt_token_ids)
    # Use as many requests as fit within both the sequence limit and the
    # batched-token budget of the prefill step.
    num_reqs = min(
        model_runner.scheduler_config.max_num_seqs,
        model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
    )

    num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
    req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]

    # Pooling models get default PoolingParams; all other models get
    # SamplingParams exercising all sampling features.
    if model_runner.is_pooling_model:
        sampling_params = None
        pooling_params = PoolingParams()
    else:
        sampling_params = SamplingParams.for_sampler_warmup()
        pooling_params = None

    # Step 1: Prefill all requests with 2 prompt tokens each.
    new_reqs = [
        NewRequestData.from_request(
            Request(req_id, prompt_token_ids, sampling_params, pooling_params),
            # Each request uses a distinct block per KV cache group.
            block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
            prefill_token_ids=prompt_token_ids,
        )
        for i, req_id in enumerate(req_ids)
    ]

    prefill_output = SchedulerOutput.make_empty()
    prefill_output.scheduled_new_reqs = new_reqs
    prefill_output.num_scheduled_tokens = {rid: prompt_len for rid in req_ids}
    prefill_output.total_num_scheduled_tokens = prompt_len * num_reqs
    prefill_output.num_common_prefix_blocks = [0] * num_kv_cache_groups

    # Disable KV connector for warmup run. The try/finally guarantees it is
    # re-enabled even if a warmup step fails, so the runner is never left
    # with the connector stuck in the disabled state.
    model_runner.kv_connector.set_disabled(True)
    try:
        model_runner.execute_model(prefill_output)

        if not model_runner.is_pooling_model:
            # Warm up sampler and perform a decode step for non-pooling
            # models.
            grammar_output = None
            if model_runner.is_last_pp_rank:
                # Build a GrammarOutput to exercise the structured output
                # bitmask kernel during the prefill step.
                vocab_size = model_runner.model_config.get_vocab_size()
                # One int32 word holds 32 vocabulary bits (round up).
                bitmask_width = (vocab_size + 31) // 32
                # -1 as int32 sets every bit of each word.
                grammar_bitmask = np.full(
                    (len(req_ids), bitmask_width), fill_value=-1, dtype=np.int32
                )
                grammar_output = GrammarOutput(
                    structured_output_request_ids=req_ids,
                    grammar_bitmask=grammar_bitmask,
                )

            model_runner.sample_tokens(grammar_output)

            # Step 2: Decode all requests with 1 token each.
            cached_req_data = CachedRequestData.make_empty()
            cached_req_data.req_ids = list(req_ids)
            cached_req_data.new_block_ids = [None] * num_reqs
            cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
            cached_req_data.num_output_tokens = [1] * num_reqs

            decode_output = SchedulerOutput.make_empty()
            decode_output.scheduled_cached_reqs = cached_req_data
            decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
            decode_output.total_num_scheduled_tokens = num_reqs
            decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups

            model_runner.execute_model(decode_output)
            model_runner.sample_tokens(None)

        # Clean up - process finish_req_ids.
        cleanup_output = SchedulerOutput.make_empty()
        cleanup_output.finished_req_ids = set(req_ids)
        model_runner.execute_model(cleanup_output)
    finally:
        model_runner.kv_connector.set_disabled(False)

    torch.cuda.synchronize()
...@@ -61,6 +61,7 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp ...@@ -61,6 +61,7 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from .gpu.warmup import warmup_kernels
from .utils import request_memory from .utils import request_memory
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -558,12 +559,15 @@ class Worker(WorkerBase): ...@@ -558,12 +559,15 @@ class Worker(WorkerBase):
logger.debug(msg) logger.debug(msg)
# Warm up sampler and preallocate memory buffer for logits and other if self.use_v2_model_runner:
# sampling related tensors of max possible shape to avoid memory # V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
# fragmentation issue. warmup_kernels(self.model_runner)
# NOTE: This is called after `capture_model` on purpose to prevent elif get_pp_group().is_last_rank:
# memory buffers from being cleared by `torch.cuda.empty_cache`. # V1: Warm up sampler and preallocate memory buffer for logits and other
if get_pp_group().is_last_rank: # sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
max_num_reqs = min( max_num_reqs = min(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_batched_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment