Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d55e446d
Unverified
Commit
d55e446d
authored
May 23, 2025
by
qizixi
Committed by
GitHub
May 24, 2025
Browse files
[V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (#18424)
Signed-off-by:
qizixi
<
qizixi@meta.com
>
parent
ec82c3e3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
19 deletions
+21
-19
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+5
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+3
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+13
-11
No files found.
tests/v1/spec_decode/test_eagle.py
View file @
d55e446d
...
...
@@ -100,8 +100,12 @@ def test_prepare_inputs():
dtype
=
torch
.
int32
,
device
=
device
)
# n1 + n2 + n3 - a - b -c
num_tokens
=
cu_target_query_lens
[
-
1
].
item
()
-
num_rejected_tokens
.
sum
(
).
item
()
cu_num_tokens
,
token_indices
=
EagleProposer
.
prepare_inputs
(
cu_target_query_lens
,
num_rejected_tokens
)
cu_target_query_lens
,
num_rejected_tokens
,
num_tokens
)
assert
torch
.
equal
(
cu_num_tokens
,
expected_cu_num_tokens
)
assert
token_indices
.
shape
[
0
]
==
expected_cu_num_tokens
[
-
1
].
item
()
...
...
vllm/v1/spec_decode/eagle.py
View file @
d55e446d
...
...
@@ -271,6 +271,7 @@ class EagleProposer:
cu_target_query_lens
:
torch
.
Tensor
,
# [batch_size]
num_rejected_tokens
:
torch
.
Tensor
,
num_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
...
...
@@ -288,18 +289,13 @@ class EagleProposer:
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens
=
torch
.
empty
_like
(
cu_target_query_lens
)
cu_num_tokens
=
torch
.
zeros
_like
(
cu_target_query_lens
)
torch
.
cumsum
(
num_tokens_per_req
,
dim
=
0
,
out
=
cu_num_tokens
[
1
:])
cu_num_tokens
[
0
]
=
0
# FIXME(woosuk): Avoid synchronization.
num_tokens
=
cu_num_tokens
[
-
1
].
item
()
token_indices
=
torch
.
empty
(
num_tokens
,
dtype
=
torch
.
int32
,
device
=
cu_
num_tok
ens
.
device
,
device
=
cu_
target_query_l
ens
.
device
,
)
batch_size
=
num_rejected_tokens
.
shape
[
0
]
BLOCK_SIZE
=
1024
prepare_eagle_input_kernel
[(
batch_size
,
)](
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d55e446d
...
...
@@ -34,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
...
@@ -1360,9 +1360,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
.
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
next_token_ids
=
async_
tensor
_h2d
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
target_device
=
self
.
device
,
pin_memory
=
True
)
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_name
]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
...
...
@@ -1390,14 +1391,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
n
+
1
-
len
(
valid_sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
num_rejected_tokens
=
torch
.
tensor
(
num_rejected_tokens
_tensor
=
async_
tensor
_h2d
(
num_rejected_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
target_device
=
self
.
device
,
pin_memory
=
True
)
num_tokens
=
num_scheduled_tokens
-
sum
(
num_rejected_tokens
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
num_rejected_tokens
,
num_rejected_tokens_tensor
,
num_tokens
,
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_positions
=
positions
[
token_indices
]
...
...
@@ -1408,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment