Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7e52781
Unverified
Commit
e7e52781
authored
Feb 09, 2026
by
Nick Hill
Committed by
GitHub
Feb 09, 2026
Browse files
[ModelRunner V2][BugFix] Fix `max_query_len` calculation (#34167)
Signed-off-by:
Nick Hill
<
nickhill123@gmail.com
>
parent
bb9f9730
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
6 additions
and
1 deletion
+6
-1
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+1
-1
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+1
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+3
-0
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+1
-0
No files found.
vllm/v1/worker/gpu/attn_utils.py
View file @
e7e52781
...
...
@@ -149,13 +149,13 @@ def build_attn_metadata(
num_tokens
:
int
,
query_start_loc_gpu
:
torch
.
Tensor
,
query_start_loc_cpu
:
torch
.
Tensor
,
max_query_len
:
int
,
seq_lens
:
torch
.
Tensor
,
max_seq_len
:
int
,
block_tables
:
Sequence
[
torch
.
Tensor
],
slot_mappings
:
torch
.
Tensor
,
kv_cache_config
:
KVCacheConfig
,
)
->
dict
[
str
,
Any
]:
max_query_len
=
int
(
query_start_loc_cpu
.
max
())
seq_lens
=
seq_lens
[:
num_reqs
]
attn_metadata
:
dict
[
str
,
Any
]
=
{}
...
...
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
e7e52781
...
...
@@ -267,6 +267,7 @@ def prepare_inputs_to_capture(
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
num_tokens_per_req
,
seq_lens
=
input_buffers
.
seq_lens
,
max_seq_len
=
max_model_len
,
block_tables
=
input_block_tables
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
e7e52781
...
...
@@ -274,6 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens
=
input_batch
.
num_tokens
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
),
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
(),
seq_lens
=
input_batch
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
...
...
@@ -561,6 +562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
max_query_len
=
num_scheduled_tokens
.
max
().
item
()
# Get prefill tokens.
prepare_prefill_inputs
(
...
...
@@ -624,6 +626,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
max_query_len
,
seq_lens
=
self
.
input_buffers
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
e7e52781
...
...
@@ -301,6 +301,7 @@ class EagleSpeculator:
num_tokens
=
num_reqs
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
1
,
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
],
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment