Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d6d13bd4
Unverified
Commit
d6d13bd4
authored
Aug 20, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 20, 2025
Browse files
[Misc] Add max_seq_len to CommonAttentionMetadata (#23216)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
5efd6905
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
22 additions
and
7 deletions
+22
-7
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+2
-0
tests/v1/spec_decode/test_tree_attention.py
tests/v1/spec_decode/test_tree_attention.py
+2
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+1
-1
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+1
-1
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+1
-1
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+1
-1
vllm/v1/attention/backends/tree_attn.py
vllm/v1/attention/backends/tree_attn.py
+1
-1
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+1
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+6
-0
vllm/v1/attention/backends/xformers.py
vllm/v1/attention/backends/xformers.py
+1
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+1
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+4
-0
No files found.
tests/v1/attention/utils.py
View file @
d6d13bd4
...
@@ -58,6 +58,7 @@ def create_common_attn_metadata(
...
@@ -58,6 +58,7 @@ def create_common_attn_metadata(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
seq_lens_cpu
=
seq_lens
.
cpu
()
seq_lens_cpu
=
seq_lens
.
cpu
()
max_seq_len
=
int
(
seq_lens_cpu
.
max
())
# Create computed tokens (context length for each sequence)
# Create computed tokens (context length for each sequence)
context_lens
=
[
context_lens
=
[
...
@@ -101,6 +102,7 @@ def create_common_attn_metadata(
...
@@ -101,6 +102,7 @@ def create_common_attn_metadata(
num_reqs
=
batch_spec
.
batch_size
,
num_reqs
=
batch_spec
.
batch_size
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_tensor
,
block_table_tensor
=
block_table_tensor
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
causal
=
True
,
causal
=
True
,
...
...
tests/v1/spec_decode/test_tree_attention.py
View file @
d6d13bd4
...
@@ -50,6 +50,7 @@ def forward_attention(
...
@@ -50,6 +50,7 @@ def forward_attention(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
)
)
context_lens
=
seq_lens
-
query_lens
context_lens
=
seq_lens
-
query_lens
max_seq_len
=
int
(
seq_lens
.
max
())
max_query_len
=
q_len
max_query_len
=
q_len
num_actual_tokens
=
query_start_loc
[
-
1
]
num_actual_tokens
=
query_start_loc
[
-
1
]
...
@@ -81,6 +82,7 @@ def forward_attention(
...
@@ -81,6 +82,7 @@ def forward_attention(
num_reqs
=
batch_size
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table
,
block_table_tensor
=
block_table
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
)
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
d6d13bd4
...
@@ -233,7 +233,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -233,7 +233,7 @@ class FlashAttentionMetadataBuilder(
num_reqs
=
common_attn_metadata
.
num_reqs
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
max_query_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
int
(
common_attn_metadata
.
seq_len
s_cpu
.
max
())
max_seq_len
=
common_attn_metadata
.
max_
seq_len
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
d6d13bd4
...
@@ -463,7 +463,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -463,7 +463,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
page_size
=
self
.
page_size
page_size
=
self
.
page_size
max_q_len
=
common_attn_metadata
.
max_query_len
max_q_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
common_attn_metadata
.
seq_len
s_cpu
.
max
().
item
()
max_seq_len
=
common_attn_metadata
.
max_
seq_len
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
d6d13bd4
...
@@ -305,7 +305,7 @@ class FlexAttentionMetadataBuilder(
...
@@ -305,7 +305,7 @@ class FlexAttentionMetadataBuilder(
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
max_query_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
int
(
common_attn_metadata
.
seq_len
s_cpu
.
max
())
max_seq_len
=
common_attn_metadata
.
max_
seq_len
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
d6d13bd4
...
@@ -270,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -270,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder(
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
max_query_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
int
(
common_attn_metadata
.
seq_len
s_cpu
.
max
())
max_seq_len
=
common_attn_metadata
.
max_
seq_len
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
...
...
vllm/v1/attention/backends/tree_attn.py
View file @
d6d13bd4
...
@@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder(
...
@@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder(
q_start_loc
=
common_attn_metadata
.
query_start_loc
q_start_loc
=
common_attn_metadata
.
query_start_loc
max_query_len
=
common_attn_metadata
.
max_query_len
max_query_len
=
common_attn_metadata
.
max_query_len
kv_seqlens
=
common_attn_metadata
.
seq_lens
kv_seqlens
=
common_attn_metadata
.
seq_lens
max_seq_len
=
int
(
common_attn_metadata
.
seq_len
s_cpu
.
max
())
max_seq_len
=
common_attn_metadata
.
max_
seq_len
block_table
=
common_attn_metadata
.
block_table_tensor
block_table
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
slot_mapping
=
common_attn_metadata
.
slot_mapping
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
d6d13bd4
...
@@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder(
...
@@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder(
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
max_query_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
int
(
common_attn_metadata
.
seq_len
s_cpu
.
max
())
max_seq_len
=
common_attn_metadata
.
max_
seq_len
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
...
...
vllm/v1/attention/backends/utils.py
View file @
d6d13bd4
...
@@ -58,6 +58,8 @@ class CommonAttentionMetadata:
...
@@ -58,6 +58,8 @@ class CommonAttentionMetadata:
"""Total number of tokens in batch"""
"""Total number of tokens in batch"""
max_query_len
:
int
max_query_len
:
int
"""Longest query in batch"""
"""Longest query in batch"""
max_seq_len
:
int
"""Longest context length in batch"""
block_table_tensor
:
torch
.
Tensor
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
...
@@ -107,6 +109,7 @@ def _make_metadata_with_slice(
...
@@ -107,6 +109,7 @@ def _make_metadata_with_slice(
seq_lens
=
attn_metadata
.
seq_lens
[
request_slice
]
seq_lens
=
attn_metadata
.
seq_lens
[
request_slice
]
seq_lens_cpu
=
attn_metadata
.
seq_lens_cpu
[
request_slice
]
seq_lens_cpu
=
attn_metadata
.
seq_lens_cpu
[
request_slice
]
max_seq_len
=
int
(
seq_lens_cpu
.
max
())
num_computed_tokens_cpu
=
attn_metadata
.
num_computed_tokens_cpu
[
num_computed_tokens_cpu
=
attn_metadata
.
num_computed_tokens_cpu
[
request_slice
]
request_slice
]
...
@@ -128,6 +131,7 @@ def _make_metadata_with_slice(
...
@@ -128,6 +131,7 @@ def _make_metadata_with_slice(
num_reqs
=
num_requests
,
num_reqs
=
num_requests
,
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_tensor
,
block_table_tensor
=
block_table_tensor
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
)
)
...
@@ -520,6 +524,7 @@ def make_local_attention_virtual_batches(
...
@@ -520,6 +524,7 @@ def make_local_attention_virtual_batches(
query_start_loc_cpu
=
torch
.
from_numpy
(
cu_seqlens_q_local
)
query_start_loc_cpu
=
torch
.
from_numpy
(
cu_seqlens_q_local
)
seq_lens_cpu
=
torch
.
from_numpy
(
seqlens_k_local
)
seq_lens_cpu
=
torch
.
from_numpy
(
seqlens_k_local
)
max_seq_len
=
int
(
seq_lens_cpu
.
max
())
return
CommonAttentionMetadata
(
return
CommonAttentionMetadata
(
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
...
@@ -531,6 +536,7 @@ def make_local_attention_virtual_batches(
...
@@ -531,6 +536,7 @@ def make_local_attention_virtual_batches(
num_reqs
=
len
(
seq_lens_cpu
),
num_reqs
=
len
(
seq_lens_cpu
),
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
max_query_len
=
seqlens_q_local
.
max
(),
max_query_len
=
seqlens_q_local
.
max
(),
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_local
,
block_table_tensor
=
block_table_local
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
causal
=
True
,
causal
=
True
,
...
...
vllm/v1/attention/backends/xformers.py
View file @
d6d13bd4
...
@@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder(
...
@@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder(
q_seqlens
=
torch
.
diff
(
q_start_loc
)
q_seqlens
=
torch
.
diff
(
q_start_loc
)
max_query_len
=
common_attn_metadata
.
max_query_len
max_query_len
=
common_attn_metadata
.
max_query_len
kv_seqlens
=
common_attn_metadata
.
seq_lens
kv_seqlens
=
common_attn_metadata
.
seq_lens
max_seq_len
=
int
(
common_attn_metadata
.
seq_len
s_cpu
.
max
())
max_seq_len
=
common_attn_metadata
.
max_
seq_len
block_table
=
common_attn_metadata
.
block_table_tensor
block_table
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
slot_mapping
=
common_attn_metadata
.
slot_mapping
...
...
vllm/v1/spec_decode/eagle.py
View file @
d6d13bd4
...
@@ -582,6 +582,7 @@ class EagleProposer:
...
@@ -582,6 +582,7 @@ class EagleProposer:
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_actual_tokens
=
total_num_tokens
,
num_actual_tokens
=
total_num_tokens
,
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_seq_len
=
new_seq_lens_cpu
.
max
().
item
(),
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
[
token_indices
],
slot_mapping
=
common_attn_metadata
.
slot_mapping
[
token_indices
],
causal
=
True
,
causal
=
True
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d6d13bd4
...
@@ -774,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -774,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
seq_lens_np
[
num_reqs
:].
fill
(
0
)
self
.
seq_lens_np
[
num_reqs
:].
fill
(
0
)
self
.
seq_lens
.
copy_
(
self
.
seq_lens_cpu
,
non_blocking
=
True
)
self
.
seq_lens
.
copy_
(
self
.
seq_lens_cpu
,
non_blocking
=
True
)
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
max_seq_len
=
self
.
seq_lens_np
[:
num_reqs
].
max
().
item
()
# Copy the tensors to the GPU.
# Copy the tensors to the GPU.
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
...
@@ -886,6 +887,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -886,6 +887,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
blk_table_tensor
,
block_table_tensor
=
blk_table_tensor
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
causal
=
True
,
causal
=
True
,
...
@@ -2338,6 +2340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2338,6 +2340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_seq_len
=
self
.
max_model_len
,
block_table_tensor
=
self
.
input_batch
.
block_table
[
block_table_tensor
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
].
get_device_tensor
()[:
num_reqs
],
kv_cache_group_id
].
get_device_tensor
()[:
num_reqs
],
slot_mapping
=
self
.
input_batch
.
slot_mapping
=
self
.
input_batch
.
...
@@ -3343,6 +3346,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3343,6 +3346,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_seq_len
=
self
.
seq_lens_cpu
[:
num_reqs
].
max
().
item
(),
block_table_tensor
=
dummy_block_table
,
block_table_tensor
=
dummy_block_table
,
slot_mapping
=
dummy_slot_mapping
,
slot_mapping
=
dummy_slot_mapping
,
causal
=
False
,
causal
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment