Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d55e446d
Unverified
Commit
d55e446d
authored
May 23, 2025
by
qizixi
Committed by
GitHub
May 24, 2025
Browse files
[V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (#18424)
Signed-off-by:
qizixi
<
qizixi@meta.com
>
parent
ec82c3e3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
19 deletions
+21
-19
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+5
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+3
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+13
-11
No files found.
tests/v1/spec_decode/test_eagle.py
View file @
d55e446d
...
...
@@ -100,8 +100,12 @@ def test_prepare_inputs():
dtype
=
torch
.
int32
,
device
=
device
)
# n1 + n2 + n3 - a - b -c
num_tokens
=
cu_target_query_lens
[
-
1
].
item
()
-
num_rejected_tokens
.
sum
(
).
item
()
cu_num_tokens
,
token_indices
=
EagleProposer
.
prepare_inputs
(
cu_target_query_lens
,
num_rejected_tokens
)
cu_target_query_lens
,
num_rejected_tokens
,
num_tokens
)
assert
torch
.
equal
(
cu_num_tokens
,
expected_cu_num_tokens
)
assert
token_indices
.
shape
[
0
]
==
expected_cu_num_tokens
[
-
1
].
item
()
...
...
vllm/v1/spec_decode/eagle.py
View file @
d55e446d
...
...
@@ -271,6 +271,7 @@ class EagleProposer:
cu_target_query_lens
:
torch
.
Tensor
,
# [batch_size]
num_rejected_tokens
:
torch
.
Tensor
,
num_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
...
...
@@ -288,18 +289,13 @@ class EagleProposer:
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens
=
torch
.
empty
_like
(
cu_target_query_lens
)
cu_num_tokens
=
torch
.
zeros
_like
(
cu_target_query_lens
)
torch
.
cumsum
(
num_tokens_per_req
,
dim
=
0
,
out
=
cu_num_tokens
[
1
:])
cu_num_tokens
[
0
]
=
0
# FIXME(woosuk): Avoid synchronization.
num_tokens
=
cu_num_tokens
[
-
1
].
item
()
token_indices
=
torch
.
empty
(
num_tokens
,
dtype
=
torch
.
int32
,
device
=
cu_
num_tok
ens
.
device
,
device
=
cu_
target_query_l
ens
.
device
,
)
batch_size
=
num_rejected_tokens
.
shape
[
0
]
BLOCK_SIZE
=
1024
prepare_eagle_input_kernel
[(
batch_size
,
)](
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d55e446d
...
...
@@ -34,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
...
@@ -281,7 +281,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
...
...
@@ -1360,9 +1360,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
.
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
next_token_ids
=
async_tensor_h2d
(
next_token_ids
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_name
]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
...
...
@@ -1390,14 +1391,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
n
+
1
-
len
(
valid_sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
num_rejected_tokens
=
torch
.
tensor
(
num_rejected_tokens
_tensor
=
async_
tensor
_h2d
(
num_rejected_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
target_device
=
self
.
device
,
pin_memory
=
True
)
num_tokens
=
num_scheduled_tokens
-
sum
(
num_rejected_tokens
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
num_rejected_tokens
,
num_rejected_tokens_tensor
,
num_tokens
,
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_positions
=
positions
[
token_indices
]
...
...
@@ -1408,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment