Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4874e3e0
Commit
4874e3e0
authored
Aug 13, 2025
by
zhuwenwen
Browse files
[feat]优化mtp/eagle的计算逻辑,减少第1层并行解码的计算重复(num_accepted_tokens_tensor修改暂未合入)
parent
295dfac8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
16 deletions
+38
-16
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+15
-9
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+1
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+8
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+14
-2
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
4874e3e0
...
@@ -754,12 +754,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -754,12 +754,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata
=
None
decode_metadata
=
None
if
num_decodes
>
0
:
if
num_decodes
>
0
:
if
self
.
use_spec_decode
:
if
self
.
use_spec_decode
and
not
common_attn_metadata
.
spec_layer_decoding
:
query_lens
=
self
.
num_scheduled_tokens_np
[:
num_decodes
]
query_lens
=
self
.
num_scheduled_tokens_np
[:
num_decodes
]
if
common_attn_metadata
.
num_rejected_tokens
is
not
None
:
num_rejected_tokens
=
common_attn_metadata
.
num_rejected_tokens
[:
num_decodes
]
query_lens
=
query_lens
-
np
.
array
(
num_rejected_tokens
,
dtype
=
np
.
int32
)
self
.
_num_decode_tokens
-=
sum
(
num_rejected_tokens
)
cu_num_blocks
=
np
.
cumsum
(
query_lens
)
cu_num_blocks
=
np
.
cumsum
(
query_lens
)
virtual_batches
=
cu_num_blocks
[
-
1
]
virtual_batches
=
cu_num_blocks
[
-
1
]
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
query_lens
,
query_lens
)
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
query_lens
,
query_lens
)
...
@@ -789,10 +785,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -789,10 +785,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
block_table_tensor
=
decode_block_table_tensor
,
block_table_tensor
=
decode_block_table_tensor
,
seq_lens
=
decode_seq_lens
,
seq_lens
=
decode_seq_lens
,
)
)
else
:
self
.
_num_decode_tokens
=
num_decodes
if
self
.
use_spec_decode
and
self
.
spec_decode_block_table_tensor
is
not
None
:
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
].
copy_
(
block_table_tensor
[:
self
.
_num_decode_tokens
,
...])
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
].
copy_
(
seq_lens
[:
self
.
_num_decode_tokens
])
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
,
...],
seq_lens
=
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
],
)
else
:
else
:
decode_metadata
=
self
.
_build_decode
(
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
block_table_tensor
[:
num_decodes
,
...],
block_table_tensor
=
block_table_tensor
[:
self
.
_
num_decode
_token
s
,
...],
seq_lens
=
seq_lens
[:
num_decodes
],
seq_lens
=
seq_lens
[:
self
.
_
num_decode
_token
s
],
)
)
attn_metadata
=
self
.
metadata_cls
(
attn_metadata
=
self
.
metadata_cls
(
...
...
vllm/v1/attention/backends/utils.py
View file @
4874e3e0
...
@@ -58,12 +58,11 @@ class CommonAttentionMetadata:
...
@@ -58,12 +58,11 @@ class CommonAttentionMetadata:
block_table_tensor
:
torch
.
Tensor
block_table_tensor
:
torch
.
Tensor
num_rejected_tokens
:
list
[
int
]
=
None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens
:
int
=
0
num_speculative_tokens
:
int
=
0
"""Number of speculative tokens"""
"""Number of speculative tokens"""
slot_mapping
:
torch
.
Tensor
=
None
slot_mapping
:
torch
.
Tensor
=
None
"""(batch_size, seq_len), slot mapping"""
"""(batch_size, seq_len), slot mapping"""
spec_layer_decoding
:
bool
=
False
M
=
TypeVar
(
"M"
)
M
=
TypeVar
(
"M"
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
4874e3e0
...
@@ -98,6 +98,7 @@ class EagleProposer:
...
@@ -98,6 +98,7 @@ class EagleProposer:
next_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
decoding
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
...
@@ -141,7 +142,7 @@ class EagleProposer:
...
@@ -141,7 +142,7 @@ class EagleProposer:
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
self
.
use_full_cuda_graph
if
(
decoding
and
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
assert
self
.
attn_metadata_cudagraph
if
self
.
method
==
"deepseek_mtp"
:
if
self
.
method
==
"deepseek_mtp"
:
...
@@ -166,7 +167,8 @@ class EagleProposer:
...
@@ -166,7 +167,8 @@ class EagleProposer:
with
set_forward_context
(
per_layer_attn_metadata
,
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
):
num_tokens
=
num_input_tokens
,
skip_cuda_graphs
=
not
decoding
):
ret_hidden_states
=
self
.
model
(
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
...
@@ -329,9 +331,11 @@ class EagleProposer:
...
@@ -329,9 +331,11 @@ class EagleProposer:
def
prepare_inputs
(
def
prepare_inputs
(
self
,
self
,
# cu_target_query_lens: torch.Tensor,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
# [batch_size]
# [batch_size]
num_rejected_tokens
:
torch
.
Tensor
num_rejected_tokens
:
torch
.
Tensor
,
# num_accepted_tokens_tensor: torch.Tensor,
)
->
tuple
[
CommonAttentionMetadata
,
torch
.
Tensor
]:
)
->
tuple
[
CommonAttentionMetadata
,
torch
.
Tensor
]:
"""
"""
This function is used to prepare the inputs for the spec decode.
This function is used to prepare the inputs for the spec decode.
...
@@ -403,6 +407,7 @@ class EagleProposer:
...
@@ -403,6 +407,7 @@ class EagleProposer:
token_indices_np
=
token_offests
+
old_query_start_locs_expanded
token_indices_np
=
token_offests
+
old_query_start_locs_expanded
token_indices
=
torch
.
from_numpy
(
token_indices_np
).
to
(
token_indices
=
torch
.
from_numpy
(
token_indices_np
).
to
(
device
,
non_blocking
=
True
)
device
,
non_blocking
=
True
)
# token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1]
spec_common_attn_metadata
=
CommonAttentionMetadata
(
spec_common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
new_query_start_loc_cpu
.
to
(
device
,
query_start_loc
=
new_query_start_loc_cpu
.
to
(
device
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4874e3e0
...
@@ -1732,7 +1732,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1732,7 +1732,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
num_rejected_tokens
=
None
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
...
@@ -1757,6 +1756,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1757,6 +1756,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
drafter
.
prepare_inputs
(
self
.
drafter
.
prepare_inputs
(
common_attn_metadata
,
num_rejected_tokens_cpu
)
common_attn_metadata
,
num_rejected_tokens_cpu
)
# num_accepted_tokens = [len(s) - 1 for s in sampled_token_ids]
# num_accepted_tokens_tensor = async_tensor_h2d(
# num_accepted_tokens,
# dtype=torch.int32,
# target_device=self.device,
# pin_memory=True)
# num_accepted_tokens_cpu = torch.tensor(num_accepted_tokens,
# dtype=torch.int32)
# common_attn_metadata, token_indices =\
# self.drafter.prepare_inputs(
# common_attn_metadata, num_accepted_tokens_cpu)
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_token_ids
=
self
.
input_ids
[
token_indices
]
# TODO(woosuk): Support M-RoPE.
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[
token_indices
]
target_positions
=
self
.
positions
[
token_indices
]
...
@@ -1772,7 +1784,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1772,7 +1784,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
num_rejected_tokens
=
num_rejected_tokens
decoding
=
spec_decode_metadata
is
not
None
)
)
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment