Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
24e6ad3f
Unverified
Commit
24e6ad3f
authored
Apr 30, 2025
by
Chen Zhang
Committed by
GitHub
Apr 29, 2025
Browse files
[V1] Remove num_input_tokens from attn_metadata (#17193)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
2ef5d106
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
14 additions
and
21 deletions
+14
-21
vllm/forward_context.py
vllm/forward_context.py
+7
-9
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+0
-3
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+0
-3
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+0
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+4
-1
No files found.
vllm/forward_context.py
View file @
24e6ad3f
...
...
@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
if
attn_metadata
is
not
None
:
if
hasattr
(
attn_metadata
,
"num_prefill_tokens"
):
if
attn_metadata
is
not
None
and
hasattr
(
attn_metadata
,
"num_prefill_tokens"
):
# for v0 attention backends
batchsize
=
attn_metadata
.
num_prefill_tokens
+
\
attn_metadata
.
num_decode_tokens
else
:
# for v1 attention backends
batchsize
=
attn_metadata
.
num_input_tokens
else
:
# for v1 attention backends or no attn_metadata
batchsize
=
num_tokens
num_tokens_across_dp
=
[
0
]
*
dp_size
num_tokens_across_dp
[
dp_rank
]
=
batchsize
...
...
@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
attn_metadata
.
num_decode_tokens
else
:
# for v1 attention backends
batchsize
=
attn_metadata
.
num_input
_tokens
batchsize
=
num
_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
24e6ad3f
...
...
@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
# for local attention
@
dataclass
class
LocalAttentionMetadata
:
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
24e6ad3f
...
...
@@ -183,9 +183,6 @@ class FlashInferMetadata:
decode_wrapper
:
Optional
[
BatchDecodeWithPagedKVCacheWrapper
]
=
None
cascade_wrapper
:
Optional
[
MultiLevelCascadeAttentionWrapper
]
=
None
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
@
property
def
query_start_loc
(
self
):
# The GPUModelRunner expects to be able to access this property.
...
...
vllm/v1/attention/backends/mla/common.py
View file @
24e6ad3f
...
...
@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
num_decode_tokens
:
int
num_prefills
:
int
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
# The dimension of the attention heads
head_dim
:
Optional
[
int
]
=
None
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
24e6ad3f
...
...
@@ -1036,7 +1036,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
else
:
num_input_tokens
=
num_scheduled_tokens
attn_metadata
.
num_input_tokens
=
num_input_tokens
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
...
...
@@ -1088,7 +1087,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
):
output
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
24e6ad3f
...
...
@@ -769,7 +769,10 @@ class TPUModelRunner:
xm
.
mark_step
()
num_reqs
=
self
.
input_batch
.
num_reqs
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
scheduler_output
.
total_num_scheduled_tokens
):
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
self
.
position_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment