Unverified Commit ffb5b32b authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[MRV2] Consider spec decoding in warmup (#37812)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
parent 91fd695b
...@@ -35,7 +35,9 @@ def warmup_kernels( ...@@ -35,7 +35,9 @@ def warmup_kernels(
""" """
prompt_token_ids = [0, 1] prompt_token_ids = [0, 1]
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
decode_len = prompt_len + 1 # After prefill, one decode token is added. num_spec_steps = model_runner.num_speculative_steps
# After prefill, decode generates 1 verified + num_spec_steps draft tokens.
decode_len = prompt_len + 1 + num_spec_steps
kv_cache_groups = model_runner.kv_cache_config.kv_cache_groups kv_cache_groups = model_runner.kv_cache_config.kv_cache_groups
num_kv_cache_groups = len(kv_cache_groups) num_kv_cache_groups = len(kv_cache_groups)
...@@ -51,7 +53,8 @@ def warmup_kernels( ...@@ -51,7 +53,8 @@ def warmup_kernels(
num_reqs = min( num_reqs = min(
model_runner.scheduler_config.max_num_seqs, model_runner.scheduler_config.max_num_seqs,
model_runner.scheduler_config.max_num_batched_tokens // prompt_len, model_runner.scheduler_config.max_num_batched_tokens
// max(prompt_len, 1 + num_spec_steps),
# Reserve block 0 (null block) and ensure we have enough blocks. # Reserve block 0 (null block) and ensure we have enough blocks.
max(1, (model_runner.kv_cache_config.num_blocks - 1) // max_blocks_per_req), max(1, (model_runner.kv_cache_config.num_blocks - 1) // max_blocks_per_req),
) )
...@@ -111,7 +114,7 @@ def warmup_kernels( ...@@ -111,7 +114,7 @@ def warmup_kernels(
worker_sample_tokens(grammar_output) worker_sample_tokens(grammar_output)
# Step 2: Decode all requests with 1 token each. # Step 2: Decode all requests with 1 + num_spec_steps tokens each.
cached_req_data = CachedRequestData.make_empty() cached_req_data = CachedRequestData.make_empty()
cached_req_data.req_ids = list(req_ids) cached_req_data.req_ids = list(req_ids)
cached_req_data.num_computed_tokens = [prompt_len] * num_reqs cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
...@@ -124,8 +127,16 @@ def warmup_kernels( ...@@ -124,8 +127,16 @@ def warmup_kernels(
decode_output = SchedulerOutput.make_empty() decode_output = SchedulerOutput.make_empty()
decode_output.scheduled_cached_reqs = cached_req_data decode_output.scheduled_cached_reqs = cached_req_data
decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids} decode_output.num_scheduled_tokens = {
decode_output.total_num_scheduled_tokens = num_reqs req_id: 1 + num_spec_steps for req_id in req_ids
}
if num_spec_steps > 0:
decode_output.scheduled_spec_decode_tokens = {
req_id: [0] * num_spec_steps for req_id in req_ids
}
decode_output.total_num_scheduled_tokens = sum(
decode_output.num_scheduled_tokens.values()
)
decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
worker_execute_model(decode_output) worker_execute_model(decode_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment