Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fe85a92e
Unverified
Commit fe85a92e authored Apr 23, 2026 by Nick Hill
Committed by GitHub Apr 24, 2026
Browse files
[Core] Avoid seq_lens_cpu GPU->CPU sync (#40654)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent
62b1bbe4
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing 19 changed files with 142 additions and 26 deletions
+142
-26
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+1
-0
tests/v1/spec_decode/test_tree_attention.py
tests/v1/spec_decode/test_tree_attention.py
+3
-1
vllm/model_executor/layers/attention/cross_attention.py
vllm/model_executor/layers/attention/cross_attention.py
+12
-4
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+10
-5
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+6
-0
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+5
-4
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+4
-1
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+6
-2
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+7
-1
vllm/v1/spec_decode/dflash.py
vllm/v1/spec_decode/dflash.py
+7
-0
vllm/v1/spec_decode/llm_base_proposer.py
vllm/v1/spec_decode/llm_base_proposer.py
+9
-1
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+4
-0
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+5
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+18
-0
vllm/v1/worker/gpu/model_states/default.py
vllm/v1/worker/gpu/model_states/default.py
+8
-1
vllm/v1/worker/gpu/model_states/whisper.py
vllm/v1/worker/gpu/model_states/whisper.py
+7
-1
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+3
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-0
vllm/v1/worker/ubatch_utils.py
vllm/v1/worker/ubatch_utils.py
+25
-5
No files found.
tests/v1/attention/utils.py
View file @
fe85a92e
...
@@ -107,6 +107,7 @@ def create_common_attn_metadata(
...
@@ -107,6 +107,7 @@ def create_common_attn_metadata(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu
,
_seq_lens_cpu
=
seq_lens_cpu
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
num_reqs
=
batch_spec
.
batch_size
,
num_reqs
=
batch_spec
.
batch_size
,
...
...
tests/v1/spec_decode/test_tree_attention.py
View file @
fe85a92e
...
@@ -241,11 +241,13 @@ def forward_attention(
...
@@ -241,11 +241,13 @@ def forward_attention(
)
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
builder
=
builder_cls
(
kv_cache_spec
,
[],
vllm_config
,
q
.
device
)
builder
=
builder_cls
(
kv_cache_spec
,
[],
vllm_config
,
q
.
device
)
seq_lens_cpu
=
seq_lens
.
cpu
()
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc
.
cpu
(),
query_start_loc_cpu
=
query_start_loc
.
cpu
(),
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
_seq_lens_cpu
=
seq_lens
.
cpu
(),
seq_lens_cpu_upper_bound
=
seq_lens_cpu
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
context_lens
.
cpu
(),
_num_computed_tokens_cpu
=
context_lens
.
cpu
(),
num_reqs
=
batch_size
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
...
...
vllm/model_executor/layers/attention/cross_attention.py
View file @
fe85a92e
...
@@ -90,15 +90,23 @@ def create_cross_attention_backend(
...
@@ -90,15 +90,23 @@ def create_cross_attention_backend(
assert
new_metadata
.
encoder_seq_lens_cpu
is
not
None
assert
new_metadata
.
encoder_seq_lens_cpu
is
not
None
max_encoder_len
=
int
(
new_metadata
.
encoder_seq_lens_cpu
.
max
())
max_encoder_len
=
int
(
new_metadata
.
encoder_seq_lens_cpu
.
max
())
new_metadata
.
max_seq_len
=
max_encoder_len
new_metadata
.
max_seq_len
=
max_encoder_len
# Any computed tokens indicated decode step>1 (no chunked prefill)
# Any computed tokens indicates decode step>1 (no chunked prefill).
num_cache_decodes
=
(
# The upper bound is exact for this `> 0` test - prefill rows have
(
common_attn_metadata
.
num_computed_tokens_cpu
>
0
).
sum
().
item
()
# num_computed == 0 and decode rows have num_computed > 0.
query_lens_cpu
=
(
common_attn_metadata
.
query_start_loc_cpu
[
1
:]
-
common_attn_metadata
.
query_start_loc_cpu
[:
-
1
]
)
)
assert
common_attn_metadata
.
seq_lens_cpu_upper_bound
is
not
None
num_computed_tokens_cpu
=
(
common_attn_metadata
.
seq_lens_cpu_upper_bound
-
query_lens_cpu
)
num_cache_decodes
=
(
num_computed_tokens_cpu
>
0
).
sum
().
item
()
if
num_cache_decodes
>
0
:
if
num_cache_decodes
>
0
:
# CrossAttn KV cache has already been populated on first decoder step,
# CrossAttn KV cache has already been populated on first decoder step,
# skip slot_mapping calculation for requests that do not need
# skip slot_mapping calculation for requests that do not need
# reshape_and_cache.
# reshape_and_cache.
num_tokens
=
common_attn_metadata
.
num_computed_tokens_cpu
.
numpy
()
num_tokens
=
num_computed_tokens_cpu
.
numpy
()
new_metadata
.
encoder_seq_lens_cpu
=
np
.
where
(
new_metadata
.
encoder_seq_lens_cpu
=
np
.
where
(
num_tokens
>
0
,
0
,
new_metadata
.
encoder_seq_lens_cpu
num_tokens
>
0
,
0
,
new_metadata
.
encoder_seq_lens_cpu
)
)
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
fe85a92e
...
@@ -1822,13 +1822,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1822,13 +1822,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata
=
None
prefill_metadata
=
None
if
num_prefills
>
0
:
if
num_prefills
>
0
:
num_computed_tokens_cpu
=
(
common_attn_metadata
.
compute_num_computed_tokens
().
cpu
()
)
reqs_start
=
num_decodes
# prefill_start
reqs_start
=
num_decodes
# prefill_start
context_lens_cpu
=
num_computed_tokens_cpu
[
reqs_start
:
num_reqs
]
# Upper bound is exact for prefill rows (no D2H sync).
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu_upper_bound
assert
seq_lens_cpu
is
not
None
prefill_query_lens_cpu
=
(
query_start_loc_cpu
[
reqs_start
+
1
:
num_reqs
+
1
]
-
query_start_loc_cpu
[
reqs_start
:
num_reqs
]
)
context_lens_cpu
=
(
seq_lens_cpu
[
reqs_start
:
num_reqs
]
-
prefill_query_lens_cpu
)
max_context_len_cpu
=
context_lens_cpu
.
max
().
item
()
max_context_len_cpu
=
context_lens_cpu
.
max
().
item
()
num_prefills_with_context_cpu
=
(
context_lens_cpu
>
0
).
sum
().
item
()
num_prefills_with_context_cpu
=
(
context_lens_cpu
>
0
).
sum
().
item
()
prefill_query_start_loc
=
(
prefill_query_start_loc
=
(
...
...
vllm/v1/attention/backend.py
View file @
fe85a92e
...
@@ -397,6 +397,12 @@ class CommonAttentionMetadata:
...
@@ -397,6 +397,12 @@ class CommonAttentionMetadata:
(num_computed_tokens < num_prompt_tokens). Used by some backends to
(num_computed_tokens < num_prompt_tokens). Used by some backends to
distinguish actual decodes from short extends."""
distinguish actual decodes from short extends."""
seq_lens_cpu_upper_bound
:
torch
.
Tensor
|
None
=
None
"""(batch_size,) CPU upper bound on seq_lens. Precise for prefill rows
and for all rows outside async spec decode; optimistic for async-spec
decode rows (assumes every draft was accepted). Not safe for kernels
that need exact per-row context lengths on decode rows."""
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu
:
torch
.
Tensor
|
None
=
None
_seq_lens_cpu
:
torch
.
Tensor
|
None
=
None
_num_computed_tokens_cpu
:
torch
.
Tensor
|
None
=
None
_num_computed_tokens_cpu
:
torch
.
Tensor
|
None
=
None
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
fe85a92e
...
@@ -782,10 +782,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
...
@@ -782,10 +782,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
def
build_for_cudagraph_capture
(
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
FlexAttentionMetadata
:
)
->
FlexAttentionMetadata
:
# Use actual max_seq_len instead of max_model_len to avoid
# Use actual max_seq_len (not max_model_len) to avoid torch.compile
# torch.compile recompilation during CUDA graph capture.
# recompilation during CUDA graph capture.
common_attn_metadata
.
max_seq_len
=
(
assert
common_attn_metadata
.
seq_lens_cpu_upper_bound
is
not
None
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
()
common_attn_metadata
.
max_seq_len
=
int
(
common_attn_metadata
.
seq_lens_cpu_upper_bound
.
max
().
item
()
)
)
return
self
.
build
(
return
self
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
fe85a92e
...
@@ -364,7 +364,10 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -364,7 +364,10 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
# For pure decode batches, prefill_request_id will be None
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
# For mixed batches, it will have -1 for decode and request_id for prefill
if
num_prefills
>
0
:
if
num_prefills
>
0
:
seq_lens_cpu
=
common_attn_metadata
.
seq_lens
.
cpu
()
# Upper bound is exact for prefill rows (the `[num_decodes:]`
# slice below), so no D2H sync is needed.
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu_upper_bound
assert
seq_lens_cpu
is
not
None
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
fe85a92e
...
@@ -554,8 +554,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...
@@ -554,8 +554,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
query_start_loc_cpu
[
num_decodes
:
num_decodes
+
num_prefills
+
1
]
query_start_loc_cpu
[
num_decodes
:
num_decodes
+
num_prefills
+
1
]
)
)
max_logits_bytes
=
envs
.
VLLM_SPARSE_INDEXER_MAX_LOGITS_MB
*
1024
*
1024
max_logits_bytes
=
envs
.
VLLM_SPARSE_INDEXER_MAX_LOGITS_MB
*
1024
*
1024
# Upper bound is exact for prefill rows (the `[num_decodes:]`
# slice below).
assert
common_attn_metadata
.
seq_lens_cpu_upper_bound
is
not
None
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu_upper_bound
chunk_specs
=
split_indexer_prefill_chunks
(
chunk_specs
=
split_indexer_prefill_chunks
(
common_attn_metadata
.
seq_lens_cpu
[
num_decodes
:],
seq_lens_cpu
[
num_decodes
:],
prefill_query_lens_cpu
,
prefill_query_lens_cpu
,
self
.
max_prefill_buffer_size
,
self
.
max_prefill_buffer_size
,
max_logits_bytes
,
max_logits_bytes
,
...
@@ -566,7 +570,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...
@@ -566,7 +570,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
req_slice
,
req_slice
,
query_slice
,
query_slice
,
query_start_loc_cpu
,
query_start_loc_cpu
,
common_attn_metadata
.
seq_lens_cpu
,
seq_lens_cpu
,
common_attn_metadata
.
block_table_tensor
,
common_attn_metadata
.
block_table_tensor
,
skip_kv_gather
=
query_slice
.
start
>
0
,
skip_kv_gather
=
query_slice
.
start
>
0
,
)
)
...
...
vllm/v1/attention/backends/utils.py
View file @
fe85a92e
...
@@ -356,6 +356,7 @@ def make_local_attention_virtual_batches(
...
@@ -356,6 +356,7 @@ def make_local_attention_virtual_batches(
block_table_tensor
=
block_table_local
,
block_table_tensor
=
block_table_local
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
causal
=
True
,
causal
=
True
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu
,
_seq_lens_cpu
=
seq_lens_cpu
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
torch
.
from_numpy
(
num_computed_tokens_local
),
_num_computed_tokens_cpu
=
torch
.
from_numpy
(
num_computed_tokens_local
),
),
make_block_table
),
make_block_table
...
@@ -414,6 +415,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
...
@@ -414,6 +415,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
causal
=
True
,
causal
=
True
,
seq_lens_cpu_upper_bound
=
common_attn_metadata
.
seq_lens_cpu_upper_bound
,
_seq_lens_cpu
=
common_attn_metadata
.
_seq_lens_cpu
,
_seq_lens_cpu
=
common_attn_metadata
.
_seq_lens_cpu
,
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
)
)
...
@@ -445,7 +447,11 @@ def split_decodes_prefills_and_extends(
...
@@ -445,7 +447,11 @@ def split_decodes_prefills_and_extends(
num_reqs
=
common_attn_metadata
.
num_reqs
num_reqs
=
common_attn_metadata
.
num_reqs
num_tokens
=
common_attn_metadata
.
num_actual_tokens
num_tokens
=
common_attn_metadata
.
num_actual_tokens
query_start_loc
=
common_attn_metadata
.
query_start_loc_cpu
query_start_loc
=
common_attn_metadata
.
query_start_loc_cpu
seq_lens
=
common_attn_metadata
.
seq_lens_cpu
# Upper bound is exact for prefill rows; decode rows still satisfy
# seq_len > query_len under the optimistic bound, so `seq_lens ==
# query_lens` identifies prefills correctly either way.
assert
common_attn_metadata
.
seq_lens_cpu_upper_bound
is
not
None
seq_lens
=
common_attn_metadata
.
seq_lens_cpu_upper_bound
if
max_query_len
<=
decode_threshold
:
if
max_query_len
<=
decode_threshold
:
return
num_reqs
,
0
,
0
,
num_tokens
,
0
,
0
return
num_reqs
,
0
,
0
,
num_tokens
,
0
,
0
...
...
vllm/v1/spec_decode/dflash.py
View file @
fe85a92e
...
@@ -151,6 +151,12 @@ class DFlashProposer(SpecDecodeBaseProposer):
...
@@ -151,6 +151,12 @@ class DFlashProposer(SpecDecodeBaseProposer):
if
has_num_rejected
:
if
has_num_rejected
:
effective_seq_lens
=
effective_seq_lens
-
num_rejected_tokens_gpu
effective_seq_lens
=
effective_seq_lens
-
num_rejected_tokens_gpu
# Skip num_rejected_tokens (GPU-only); overestimating is fine here.
new_seq_lens_cpu_upper_bound
=
(
cad
.
seq_lens_cpu_upper_bound
+
num_query_per_req
if
cad
.
seq_lens_cpu_upper_bound
is
not
None
else
None
)
new_cad
=
CommonAttentionMetadata
(
new_cad
=
CommonAttentionMetadata
(
query_start_loc
=
new_query_start_loc
,
query_start_loc
=
new_query_start_loc
,
seq_lens
=
effective_seq_lens
+
num_query_per_req
,
seq_lens
=
effective_seq_lens
+
num_query_per_req
,
...
@@ -160,6 +166,7 @@ class DFlashProposer(SpecDecodeBaseProposer):
...
@@ -160,6 +166,7 @@ class DFlashProposer(SpecDecodeBaseProposer):
),
),
_seq_lens_cpu
=
None
,
_seq_lens_cpu
=
None
,
_num_computed_tokens_cpu
=
None
,
_num_computed_tokens_cpu
=
None
,
seq_lens_cpu_upper_bound
=
new_seq_lens_cpu_upper_bound
,
num_reqs
=
cad
.
num_reqs
,
num_reqs
=
cad
.
num_reqs
,
num_actual_tokens
=
num_query_total
,
num_actual_tokens
=
num_query_total
,
max_query_len
=
num_query_per_req
,
max_query_len
=
num_query_per_req
,
...
...
vllm/v1/spec_decode/llm_base_proposer.py
View file @
fe85a92e
...
@@ -593,6 +593,8 @@ class SpecDecodeBaseProposer:
...
@@ -593,6 +593,8 @@ class SpecDecodeBaseProposer:
common_attn_metadata
.
_seq_lens_cpu
+=
1
common_attn_metadata
.
_seq_lens_cpu
+=
1
if
common_attn_metadata
.
_num_computed_tokens_cpu
is
not
None
:
if
common_attn_metadata
.
_num_computed_tokens_cpu
is
not
None
:
common_attn_metadata
.
_num_computed_tokens_cpu
+=
1
common_attn_metadata
.
_num_computed_tokens_cpu
+=
1
if
common_attn_metadata
.
seq_lens_cpu_upper_bound
is
not
None
:
common_attn_metadata
.
seq_lens_cpu_upper_bound
+=
1
# Rebuild attention metadata
# Rebuild attention metadata
_
,
per_layer_attn_metadata
=
self
.
build_per_group_and_layer_attn_metadata
(
_
,
per_layer_attn_metadata
=
self
.
build_per_group_and_layer_attn_metadata
(
...
@@ -959,6 +961,7 @@ class SpecDecodeBaseProposer:
...
@@ -959,6 +961,7 @@ class SpecDecodeBaseProposer:
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
_seq_lens_cpu
=
common_attn_metadata
.
_seq_lens_cpu
,
_seq_lens_cpu
=
common_attn_metadata
.
_seq_lens_cpu
,
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
seq_lens_cpu_upper_bound
=
common_attn_metadata
.
seq_lens_cpu_upper_bound
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_actual_tokens
=
total_num_tokens
,
num_actual_tokens
=
total_num_tokens
,
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
...
@@ -1183,7 +1186,11 @@ class SpecDecodeBaseProposer:
...
@@ -1183,7 +1186,11 @@ class SpecDecodeBaseProposer:
device
=
common_attn_metadata
.
query_start_loc
.
device
device
=
common_attn_metadata
.
query_start_loc
.
device
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
new_seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
-
num_rejected_tokens
# upper_bound - rejected = actual post-rejection seq_lens (no D2H sync).
assert
common_attn_metadata
.
seq_lens_cpu_upper_bound
is
not
None
new_seq_lens_cpu
=
(
common_attn_metadata
.
seq_lens_cpu_upper_bound
-
num_rejected_tokens
)
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
new_query_len_per_req
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
...
@@ -1237,6 +1244,7 @@ class SpecDecodeBaseProposer:
...
@@ -1237,6 +1244,7 @@ class SpecDecodeBaseProposer:
query_start_loc_cpu
=
new_query_start_loc_cpu
,
query_start_loc_cpu
=
new_query_start_loc_cpu
,
_seq_lens_cpu
=
new_seq_lens_cpu
,
_seq_lens_cpu
=
new_seq_lens_cpu
,
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
seq_lens_cpu_upper_bound
=
new_seq_lens_cpu
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_actual_tokens
=
total_num_tokens
,
num_actual_tokens
=
total_num_tokens
,
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
...
...
vllm/v1/worker/gpu/attn_utils.py
View file @
fe85a92e
...
@@ -227,12 +227,15 @@ def build_attn_metadata(
...
@@ -227,12 +227,15 @@ def build_attn_metadata(
block_tables
:
Sequence
[
torch
.
Tensor
],
block_tables
:
Sequence
[
torch
.
Tensor
],
slot_mappings
:
torch
.
Tensor
,
slot_mappings
:
torch
.
Tensor
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
seq_lens_cpu_upper_bound
:
torch
.
Tensor
|
None
=
None
,
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
,
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
,
encoder_seq_lens
:
dict
[
int
,
tuple
[
torch
.
Tensor
,
np
.
ndarray
]]
|
None
=
None
,
encoder_seq_lens
:
dict
[
int
,
tuple
[
torch
.
Tensor
,
np
.
ndarray
]]
|
None
=
None
,
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
seq_lens
=
seq_lens
[:
num_reqs
]
seq_lens
=
seq_lens
[:
num_reqs
]
if
dcp_local_seq_lens
is
not
None
:
if
dcp_local_seq_lens
is
not
None
:
dcp_local_seq_lens
=
dcp_local_seq_lens
[:
num_reqs
]
dcp_local_seq_lens
=
dcp_local_seq_lens
[:
num_reqs
]
if
seq_lens_cpu_upper_bound
is
not
None
:
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
[:
num_reqs
]
attn_metadata
:
dict
[
str
,
Any
]
=
{}
attn_metadata
:
dict
[
str
,
Any
]
=
{}
num_kv_cache_groups
=
len
(
kv_cache_config
.
kv_cache_groups
)
num_kv_cache_groups
=
len
(
kv_cache_config
.
kv_cache_groups
)
...
@@ -244,6 +247,7 @@ def build_attn_metadata(
...
@@ -244,6 +247,7 @@ def build_attn_metadata(
query_start_loc
=
query_start_loc_gpu
,
query_start_loc
=
query_start_loc_gpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
fe85a92e
...
@@ -60,6 +60,8 @@ class InputBatch:
...
@@ -60,6 +60,8 @@ class InputBatch:
query_start_loc_np
:
np
.
ndarray
query_start_loc_np
:
np
.
ndarray
# [num_reqs]
# [num_reqs]
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
# [num_reqs] CPU upper bound on seq_lens (see CommonAttentionMetadata).
seq_lens_cpu_upper_bound
:
torch
.
Tensor
# [num_reqs]
# [num_reqs]
dcp_local_seq_lens
:
torch
.
Tensor
|
None
dcp_local_seq_lens
:
torch
.
Tensor
|
None
...
@@ -121,6 +123,8 @@ class InputBatch:
...
@@ -121,6 +123,8 @@ class InputBatch:
logits_indices
=
query_start_loc
[
1
:]
-
1
logits_indices
=
query_start_loc
[
1
:]
-
1
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_num_logits_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
cu_num_logits_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
# Dummy: seq_len == query_len (fresh-prefill shape).
seq_lens_cpu_upper_bound
=
torch
.
from_numpy
(
num_scheduled_tokens
.
copy
())
return
cls
(
return
cls
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
...
@@ -136,6 +140,7 @@ class InputBatch:
...
@@ -136,6 +140,7 @@ class InputBatch:
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
,
dcp_local_seq_lens
=
None
,
dcp_local_seq_lens
=
None
,
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
fe85a92e
...
@@ -799,6 +799,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -799,6 +799,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total_num_logits
,
total_num_logits
,
)
)
# CPU upper bound on seq_lens; padded entries left at zero.
seq_lens_cpu_upper_bound_np
=
np
.
zeros
(
num_reqs_padded
,
dtype
=
np
.
int32
)
np
.
add
(
self
.
req_states
.
num_computed_tokens_np
[
idx_mapping_np
],
num_scheduled_tokens
,
out
=
seq_lens_cpu_upper_bound_np
[:
num_reqs
],
)
seq_lens_cpu_upper_bound
=
torch
.
from_numpy
(
seq_lens_cpu_upper_bound_np
)
return
InputBatch
(
return
InputBatch
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
...
@@ -814,6 +823,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -814,6 +823,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
],
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
],
...
@@ -927,6 +937,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -927,6 +937,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
np
.
minimum
(
np
.
minimum
(
computed_prefill
,
self
.
req_states
.
prefill_len
.
np
,
out
=
computed_prefill
computed_prefill
,
self
.
req_states
.
prefill_len
.
np
,
out
=
computed_prefill
)
)
# Advance the CPU mirror optimistically (assume all scheduled accepted).
self
.
req_states
.
num_computed_tokens_np
[
idx_mapping_np
]
+=
(
input_batch
.
num_scheduled_tokens
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
...
@@ -1297,6 +1311,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1297,6 +1311,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
np
.
minimum
(
np
.
minimum
(
computed_prefill
,
self
.
req_states
.
prefill_len
.
np
,
out
=
computed_prefill
computed_prefill
,
self
.
req_states
.
prefill_len
.
np
,
out
=
computed_prefill
)
)
# Advance the CPU mirror optimistically (assume all scheduled accepted).
self
.
req_states
.
num_computed_tokens_np
[
idx_mapping_np
]
+=
(
input_batch
.
num_scheduled_tokens
)
########### EPLB methods start ###########
########### EPLB methods start ###########
@
property
@
property
...
...
vllm/v1/worker/gpu/model_states/default.py
View file @
fe85a92e
...
@@ -173,6 +173,12 @@ class DefaultModelState(ModelState):
...
@@ -173,6 +173,12 @@ class DefaultModelState(ModelState):
num_tokens
=
input_batch
.
num_tokens
num_tokens
=
input_batch
.
num_tokens
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
seq_lens_cpu_upper_bound
=
input_batch
.
seq_lens_cpu_upper_bound
if
for_capture
:
# Capture with worst-case max_seq_len so the graph is valid at any replay.
max_seq_len
=
self
.
max_model_len
else
:
max_seq_len
=
int
(
seq_lens_cpu_upper_bound
[:
num_reqs
].
max
().
item
())
attn_metadata
=
build_attn_metadata
(
attn_metadata
=
build_attn_metadata
(
attn_groups
=
attn_groups
,
attn_groups
=
attn_groups
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
...
@@ -181,10 +187,11 @@ class DefaultModelState(ModelState):
...
@@ -181,10 +187,11 @@ class DefaultModelState(ModelState):
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
seq_lens
=
input_batch
.
seq_lens
,
seq_lens
=
input_batch
.
seq_lens
,
max_seq_len
=
self
.
max_model
_len
,
max_seq_len
=
max_seq
_len
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
,
dcp_local_seq_lens
=
input_batch
.
dcp_local_seq_lens
,
dcp_local_seq_lens
=
input_batch
.
dcp_local_seq_lens
,
)
)
return
attn_metadata
return
attn_metadata
vllm/v1/worker/gpu/model_states/whisper.py
View file @
fe85a92e
...
@@ -117,6 +117,11 @@ class WhisperModelState(ModelState):
...
@@ -117,6 +117,11 @@ class WhisperModelState(ModelState):
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
seq_lens_cpu_upper_bound
=
input_batch
.
seq_lens_cpu_upper_bound
if
for_capture
:
max_seq_len
=
self
.
max_model_len
else
:
max_seq_len
=
int
(
seq_lens_cpu_upper_bound
[:
num_reqs
].
max
().
item
())
attn_metadata
=
build_attn_metadata
(
attn_metadata
=
build_attn_metadata
(
attn_groups
=
attn_groups
,
attn_groups
=
attn_groups
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
...
@@ -125,10 +130,11 @@ class WhisperModelState(ModelState):
...
@@ -125,10 +130,11 @@ class WhisperModelState(ModelState):
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
seq_lens
=
input_batch
.
seq_lens
,
seq_lens
=
input_batch
.
seq_lens
,
max_seq_len
=
self
.
max_model
_len
,
max_seq_len
=
max_seq
_len
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
,
dcp_local_seq_lens
=
input_batch
.
dcp_local_seq_lens
,
dcp_local_seq_lens
=
input_batch
.
dcp_local_seq_lens
,
encoder_seq_lens
=
encoder_seq_lens
,
encoder_seq_lens
=
encoder_seq_lens
,
)
)
...
...
vllm/v1/worker/gpu/states.py
View file @
fe85a92e
...
@@ -57,6 +57,8 @@ class RequestState:
...
@@ -57,6 +57,8 @@ class RequestState:
self
.
num_computed_tokens
=
StagedWriteTensor
(
self
.
num_computed_tokens
=
StagedWriteTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
)
# Optimistic CPU mirror of num_computed_tokens (upper bound on GPU value).
self
.
num_computed_tokens_np
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
# Last sampled tokens.
# Last sampled tokens.
self
.
last_sampled_tokens
=
torch
.
zeros
(
self
.
last_sampled_tokens
=
torch
.
zeros
(
...
@@ -100,6 +102,7 @@ class RequestState:
...
@@ -100,6 +102,7 @@ class RequestState:
self
.
total_len
.
stage_write_elem
(
req_idx
,
prefill_len
)
self
.
total_len
.
stage_write_elem
(
req_idx
,
prefill_len
)
self
.
all_token_ids
.
stage_write
(
req_idx
,
0
,
all_token_ids
)
self
.
all_token_ids
.
stage_write
(
req_idx
,
0
,
all_token_ids
)
self
.
num_computed_prefill_tokens
[
req_idx
]
=
num_computed_tokens
self
.
num_computed_prefill_tokens
[
req_idx
]
=
num_computed_tokens
self
.
num_computed_tokens_np
[
req_idx
]
=
num_computed_tokens
self
.
num_computed_tokens
.
stage_write_elem
(
req_idx
,
num_computed_tokens
)
self
.
num_computed_tokens
.
stage_write_elem
(
req_idx
,
num_computed_tokens
)
if
num_computed_tokens
>
0
and
num_computed_tokens
<=
prefill_len
:
if
num_computed_tokens
>
0
and
num_computed_tokens
<=
prefill_len
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
fe85a92e
...
@@ -2155,6 +2155,7 @@ class GPUModelRunner(
...
@@ -2155,6 +2155,7 @@ class GPUModelRunner(
:
num_reqs_padded
:
num_reqs_padded
]
]
seq_lens_cpu
=
self
.
optimistic_seq_lens_cpu
[:
num_reqs_padded
]
seq_lens_cpu
=
self
.
optimistic_seq_lens_cpu
[:
num_reqs_padded
]
seq_lens_cpu_upper_bound
=
seq_lens_cpu
# is_prefilling: True if request is still in prefill phase.
# is_prefilling: True if request is still in prefill phase.
# Used by mamba backends to distinguish actual decodes from
# Used by mamba backends to distinguish actual decodes from
...
@@ -2172,6 +2173,7 @@ class GPUModelRunner(
...
@@ -2172,6 +2173,7 @@ class GPUModelRunner(
seq_lens
=
self
.
seq_lens
[:
num_reqs_padded
],
seq_lens
=
self
.
seq_lens
[:
num_reqs_padded
],
_seq_lens_cpu
=
seq_lens_cpu
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
,
num_reqs
=
num_reqs_padded
,
num_reqs
=
num_reqs_padded
,
num_actual_tokens
=
num_tokens_padded
,
num_actual_tokens
=
num_tokens_padded
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
...
...
vllm/v1/worker/ubatch_utils.py
View file @
fe85a92e
...
@@ -177,7 +177,22 @@ def _make_metadata_with_slice(
...
@@ -177,7 +177,22 @@ def _make_metadata_with_slice(
query_start_loc
[
1
:]
-=
tokens_skipped
query_start_loc
[
1
:]
-=
tokens_skipped
query_start_loc_cpu
[
1
:]
-=
tokens_skipped
query_start_loc_cpu
[
1
:]
-=
tokens_skipped
seq_lens
=
attn_metadata
.
seq_lens
[
request_slice
]
seq_lens
=
attn_metadata
.
seq_lens
[
request_slice
]
seq_lens_cpu
=
attn_metadata
.
seq_lens_cpu
[
request_slice
]
# Read raw fields to avoid triggering the deprecated D2H-syncing properties.
seq_lens_cpu
=
(
attn_metadata
.
_seq_lens_cpu
[
request_slice
]
if
attn_metadata
.
_seq_lens_cpu
is
not
None
else
None
)
seq_lens_cpu_upper_bound
=
(
attn_metadata
.
seq_lens_cpu_upper_bound
[
request_slice
]
if
attn_metadata
.
seq_lens_cpu_upper_bound
is
not
None
else
None
)
num_computed_tokens_cpu
=
(
attn_metadata
.
_num_computed_tokens_cpu
[
request_slice
]
if
attn_metadata
.
_num_computed_tokens_cpu
is
not
None
else
None
)
if
splits_last_request
:
if
splits_last_request
:
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
...
@@ -190,12 +205,16 @@ def _make_metadata_with_slice(
...
@@ -190,12 +205,16 @@ def _make_metadata_with_slice(
# Make sure we don't modify the seq_lens tensors
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
# (not cudagraph compatible)
seq_lens
=
seq_lens
.
clone
()
seq_lens
=
seq_lens
.
clone
()
seq_lens_cpu
=
seq_lens_cpu
.
clone
()
seq_lens
[
-
1
]
-=
tokens_skipped
seq_lens
[
-
1
]
-=
tokens_skipped
seq_lens_cpu
[
-
1
]
-=
tokens_skipped
if
seq_lens_cpu
is
not
None
:
seq_lens_cpu
=
seq_lens_cpu
.
clone
()
seq_lens_cpu
[
-
1
]
-=
tokens_skipped
if
seq_lens_cpu_upper_bound
is
not
None
:
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
.
clone
()
seq_lens_cpu_upper_bound
[
-
1
]
-=
tokens_skipped
max_seq_len
=
int
(
seq_lens_cpu
.
max
())
assert
seq_lens_cpu_upper_bound
is
not
None
num_computed_tokens_cpu
=
attn_metadata
.
num_computed_tokens_cpu
[
request_slice
]
max_seq_len
=
int
(
seq_lens_cpu_upper_bound
.
max
())
num_requests
=
request_slice
.
stop
-
request_slice
.
start
num_requests
=
request_slice
.
stop
-
request_slice
.
start
num_actual_tokens
=
token_slice
.
stop
-
token_slice
.
start
num_actual_tokens
=
token_slice
.
stop
-
token_slice
.
start
...
@@ -221,6 +240,7 @@ def _make_metadata_with_slice(
...
@@ -221,6 +240,7 @@ def _make_metadata_with_slice(
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_tensor
,
block_table_tensor
=
block_table_tensor
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
seq_lens_cpu_upper_bound
=
seq_lens_cpu_upper_bound
,
_seq_lens_cpu
=
seq_lens_cpu
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment