Unverified Commit e9f331d7 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[MRV2] Ensure warmup covers prefill path (#40746)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent c9bf77df
......@@ -29,13 +29,16 @@ def warmup_kernels(
triton kernels. We must call the provided worker's execute_model for
pipeline parallel coordination.
The first iteration simulates a prefill with requests of 2 prompt
tokens each. The second iteration simulates a decode step with all
requests generating 1 token each.
The first iteration simulates a prefill with requests of
2 + num_spec_steps prompt tokens each. The second iteration simulates
a decode step with all requests generating 1 + num_spec_steps tokens.
"""
prompt_token_ids = [0, 1]
prompt_len = len(prompt_token_ids)
num_spec_steps = model_runner.num_speculative_steps
# Use 1 + num_spec_steps + 1 tokens so the prefill batch's per-request
# query length exceeds decode_query_len (= 1 + num_spec_steps), preventing
# it from being misclassified as a uniform decode batch.
prompt_len = 2 + num_spec_steps
prompt_token_ids = list(range(prompt_len))
# After prefill, decode generates 1 verified + num_spec_steps draft tokens.
decode_len = prompt_len + 1 + num_spec_steps
......@@ -76,7 +79,7 @@ def warmup_kernels(
nonlocal next_block_id
return list(range(next_block_id, next_block_id := next_block_id + num_blocks))
# Step 1: Prefill all requests with 2 prompt tokens each.
# Step 1: Prefill all requests with 2 + num_spec_steps prompt tokens each.
new_reqs = [
NewRequestData.from_request(
Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment