Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
619eb032
Commit
619eb032
authored
Aug 13, 2025
by
王敏
Browse files
[feat]优化mtp/eagle的计算逻辑,减少第1层并行解码的计算重复
parent
a6bf968b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
73 additions
and
64 deletions
+73
-64
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+15
-9
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+1
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+52
-41
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+5
-12
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
619eb032
...
@@ -637,12 +637,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -637,12 +637,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata
=
None
decode_metadata
=
None
if
self
.
_num_decodes
>
0
:
if
self
.
_num_decodes
>
0
:
if
self
.
use_spec_decode
:
if
self
.
use_spec_decode
and
not
common_attn_metadata
.
spec_layer_decoding
:
query_lens
=
self
.
num_scheduled_tokens_np
[:
self
.
_num_decodes
]
query_lens
=
self
.
num_scheduled_tokens_np
[:
self
.
_num_decodes
]
if
common_attn_metadata
.
num_rejected_tokens
is
not
None
:
num_rejected_tokens
=
common_attn_metadata
.
num_rejected_tokens
[:
self
.
_num_decodes
]
query_lens
=
query_lens
-
np
.
array
(
num_rejected_tokens
,
dtype
=
np
.
int32
)
self
.
_num_decode_tokens
-=
sum
(
num_rejected_tokens
)
cu_num_blocks
=
np
.
cumsum
(
query_lens
)
cu_num_blocks
=
np
.
cumsum
(
query_lens
)
virtual_batches
=
cu_num_blocks
[
-
1
]
virtual_batches
=
cu_num_blocks
[
-
1
]
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
query_lens
,
query_lens
)
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
query_lens
,
query_lens
)
...
@@ -673,10 +669,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -673,10 +669,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens
=
decode_seq_lens
,
seq_lens
=
decode_seq_lens
,
)
)
else
:
else
:
decode_metadata
=
self
.
_build_decode
(
self
.
_num_decode_tokens
=
self
.
_num_decodes
block_table_tensor
=
block_table_tensor
[:
self
.
_num_decodes
,
...],
if
self
.
use_spec_decode
and
self
.
spec_decode_block_table_tensor
is
not
None
:
seq_lens
=
seq_lens
[:
self
.
_num_decodes
],
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
].
copy_
(
block_table_tensor
[:
self
.
_num_decode_tokens
,
...])
)
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
].
copy_
(
seq_lens
[:
self
.
_num_decode_tokens
])
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
,
...],
seq_lens
=
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
],
)
else
:
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
block_table_tensor
[:
self
.
_num_decode_tokens
,
...],
seq_lens
=
seq_lens
[:
self
.
_num_decode_tokens
],
)
return
self
.
metadata_cls
(
return
self
.
metadata_cls
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
...
...
vllm/v1/attention/backends/utils.py
View file @
619eb032
...
@@ -41,12 +41,11 @@ class CommonAttentionMetadata:
...
@@ -41,12 +41,11 @@ class CommonAttentionMetadata:
"""Total number of tokens in batch"""
"""Total number of tokens in batch"""
max_query_len
:
int
max_query_len
:
int
"""Longest query in batch"""
"""Longest query in batch"""
num_rejected_tokens
:
list
[
int
]
=
None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens
:
int
=
0
num_speculative_tokens
:
int
=
0
"""Number of speculative tokens"""
"""Number of speculative tokens"""
slot_mapping
:
torch
.
Tensor
=
None
slot_mapping
:
torch
.
Tensor
=
None
"""(batch_size, seq_len), slot mapping"""
"""(batch_size, seq_len), slot mapping"""
spec_layer_decoding
:
bool
=
False
M
=
TypeVar
(
"M"
)
M
=
TypeVar
(
"M"
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
619eb032
...
@@ -104,9 +104,8 @@ class EagleProposer:
...
@@ -104,9 +104,8 @@ class EagleProposer:
# [batch_size, max_num_blocks_per_req]
# [batch_size, max_num_blocks_per_req]
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
num_rejected_tokens
:
list
[
int
],
sampling_metadata
:
SamplingMetadata
,
# [batch_size]
decoding
:
bool
=
False
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
...
@@ -158,8 +157,8 @@ class EagleProposer:
...
@@ -158,8 +157,8 @@ class EagleProposer:
num_reqs
=
batch_size
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
num_rejected_tokens
=
num_rejected_tokens
,
slot_mapping
=
target_slot_mapping
,
s
lot_mapping
=
target_slot_mapp
ing
s
pec_layer_decoding
=
decod
ing
)
)
assert
self
.
runner
is
not
None
assert
self
.
runner
is
not
None
...
@@ -186,7 +185,7 @@ class EagleProposer:
...
@@ -186,7 +185,7 @@ class EagleProposer:
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
self
.
use_full_cuda_graph
if
(
decoding
and
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
...
@@ -214,13 +213,14 @@ class EagleProposer:
...
@@ -214,13 +213,14 @@ class EagleProposer:
if
attn_metadata
.
decode
is
not
None
:
if
attn_metadata
.
decode
is
not
None
:
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
attn_metadata
.
decode
.
block_table
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
attn_metadata
.
decode
.
seq_lens
)
with
set_forward_context
(
per_layer_attn_metadata
,
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
):
num_tokens
=
num_input_tokens
,
skip_cuda_graphs
=
not
decoding
):
ret_hidden_states
=
self
.
model
(
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
...
@@ -390,45 +390,56 @@ class EagleProposer:
...
@@ -390,45 +390,56 @@ class EagleProposer:
return
draft_token_ids
return
draft_token_ids
# @staticmethod
# def prepare_inputs(
# # [batch_size + 1]
# cu_target_query_lens: torch.Tensor,
# # [batch_size]
# num_rejected_tokens: torch.Tensor,
# num_tokens: int,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# # cu_target_query_lens: [0, a, a + b, a + b + c]
# # num_rejected_tokens: [n1, n2, n3]
# # num_tokens_per_req: [a - n1, b - n2, c - n3]
# # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# # token_indices: [0, 1, ..., a - n1 - 1,
# # a, a + 1, ..., a + b - n2 - 1,
# # a + b, a + b + 1, ..., a + b + c - n3 - 1]
# # [0, a, a + b, a + b + c] -> [a, b, c]
# query_len_per_req = (cu_target_query_lens[1:] -
# cu_target_query_lens[:-1])
# # [a, b, c] -> [a - n1, b - n2, c - n3]
# num_tokens_per_req = query_len_per_req - num_rejected_tokens
# # [a - n1, b - n2, c - n3] ->
# # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# cu_num_tokens = torch.zeros_like(cu_target_query_lens)
# torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
# token_indices = torch.empty(
# num_tokens,
# dtype=torch.int32,
# device=cu_target_query_lens.device,
# )
# batch_size = num_rejected_tokens.shape[0]
# BLOCK_SIZE = 1024
# prepare_eagle_input_kernel[(batch_size, )](
# token_indices,
# cu_target_query_lens,
# cu_num_tokens,
# BLOCK_SIZE=BLOCK_SIZE,
# )
# return cu_num_tokens, token_indices
@
staticmethod
@
staticmethod
def
prepare_inputs
(
def
prepare_inputs
(
# [batch_size + 1]
# [batch_size + 1]
cu_target_query_lens
:
torch
.
Tensor
,
cu_target_query_lens
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
num_rejected_tokens
:
torch
.
Tensor
,
num_accepted_tokens_tensor
:
torch
.
Tensor
,
num_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
cu_num_tokens
=
torch
.
arange
(
cu_target_query_lens
.
shape
[
0
],
device
=
cu_target_query_lens
.
device
)
# num_rejected_tokens: [n1, n2, n3]
token_indices
=
num_accepted_tokens_tensor
+
cu_target_query_lens
[:
-
1
]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req
=
(
cu_target_query_lens
[
1
:]
-
cu_target_query_lens
[:
-
1
])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req
=
query_len_per_req
-
num_rejected_tokens
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens
=
torch
.
zeros_like
(
cu_target_query_lens
)
torch
.
cumsum
(
num_tokens_per_req
,
dim
=
0
,
out
=
cu_num_tokens
[
1
:])
token_indices
=
torch
.
empty
(
num_tokens
,
dtype
=
torch
.
int32
,
device
=
cu_target_query_lens
.
device
,
)
batch_size
=
num_rejected_tokens
.
shape
[
0
]
BLOCK_SIZE
=
1024
prepare_eagle_input_kernel
[(
batch_size
,
)](
token_indices
,
cu_target_query_lens
,
cu_num_tokens
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
return
cu_num_tokens
,
token_indices
return
cu_num_tokens
,
token_indices
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
619eb032
...
@@ -1659,7 +1659,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1659,7 +1659,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
else
:
block_table
=
None
block_table
=
None
num_rejected_tokens
=
None
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
...
@@ -1675,21 +1674,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1675,21 +1674,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
else
:
else
:
# TODO(woosuk): Refactor this.
# TODO(woosuk): Refactor this.
num_draft_tokens
=
spec_decode_metadata
.
num_draft_tokens
num_accepted_tokens
=
[
len
(
s
)
-
1
for
s
in
sampled_token_ids
]
num_rejected_tokens
=
[
num_accepted_tokens_tensor
=
async_tensor_h2d
(
n
+
1
-
len
(
sampled_token_ids
[
i
])
if
n
>
0
else
0
num_accepted_tokens
,
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
num_rejected_tokens_tensor
=
async_tensor_h2d
(
num_rejected_tokens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
target_device
=
self
.
device
,
pin_memory
=
True
)
pin_memory
=
True
)
num_tokens
=
num_scheduled_tokens
-
sum
(
num_rejected_tokens
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
eagle_attn_metadata
.
query_start_loc
,
num_rejected_tokens_tensor
,
num_accepted_tokens_tensor
,
num_tokens
,
)
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_token_ids
=
self
.
input_ids
[
token_indices
]
# TODO(woosuk): Support M-RoPE.
# TODO(woosuk): Support M-RoPE.
...
@@ -1710,7 +1703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1710,7 +1703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens
=
cu_num_tokens
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
block_table
,
block_table
=
block_table
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
num_rejected_tokens
=
num_rejected_tokens
decoding
=
spec_decode_metadata
is
not
None
)
)
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment