Unverified Commit e9f331d7 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[MRV2] Ensure warmup covers prefill path (#40746)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent c9bf77df
...@@ -29,13 +29,16 @@ def warmup_kernels( ...@@ -29,13 +29,16 @@ def warmup_kernels(
triton kernels. We must call the provided worker's execute_model for triton kernels. We must call the provided worker's execute_model for
pipeline parallel coordination. pipeline parallel coordination.
The first iteration simulates a prefill with requests of 2 prompt The first iteration simulates a prefill with requests of
tokens each. The second iteration simulates a decode step with all 2 + num_spec_steps prompt tokens each. The second iteration simulates
requests generating 1 token each. a decode step with all requests generating 1 + num_spec_steps tokens.
""" """
prompt_token_ids = [0, 1]
prompt_len = len(prompt_token_ids)
num_spec_steps = model_runner.num_speculative_steps num_spec_steps = model_runner.num_speculative_steps
# Use 1 + num_spec_steps + 1 tokens so the prefill batch's per-request
# query length exceeds decode_query_len (= 1 + num_spec_steps), preventing
# it from being misclassified as a uniform decode batch.
prompt_len = 2 + num_spec_steps
prompt_token_ids = list(range(prompt_len))
# After prefill, decode generates 1 verified + num_spec_steps draft tokens. # After prefill, decode generates 1 verified + num_spec_steps draft tokens.
decode_len = prompt_len + 1 + num_spec_steps decode_len = prompt_len + 1 + num_spec_steps
...@@ -76,7 +79,7 @@ def warmup_kernels( ...@@ -76,7 +79,7 @@ def warmup_kernels(
nonlocal next_block_id nonlocal next_block_id
return list(range(next_block_id, next_block_id := next_block_id + num_blocks)) return list(range(next_block_id, next_block_id := next_block_id + num_blocks))
# Step 1: Prefill all requests with 2 prompt tokens each. # Step 1: Prefill all requests with 2 + num_spec_steps prompt tokens each.
new_reqs = [ new_reqs = [
NewRequestData.from_request( NewRequestData.from_request(
Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params), Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment