Unverified Commit e7e52781 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[ModelRunner V2][BugFix] Fix `max_query_len` calculation (#34167)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent bb9f9730
......@@ -149,13 +149,13 @@ def build_attn_metadata(
num_tokens: int,
query_start_loc_gpu: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
max_query_len: int,
seq_lens: torch.Tensor,
max_seq_len: int,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
max_query_len = int(query_start_loc_cpu.max())
seq_lens = seq_lens[:num_reqs]
attn_metadata: dict[str, Any] = {}
......
......@@ -267,6 +267,7 @@ def prepare_inputs_to_capture(
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=num_tokens_per_req,
seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len,
block_tables=input_block_tables,
......
......@@ -274,6 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
max_query_len=input_batch.num_scheduled_tokens.max().item(),
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
......@@ -561,6 +562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens.
prepare_prefill_inputs(
......@@ -624,6 +626,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
......
......@@ -301,6 +301,7 @@ class EagleSpeculator:
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment