Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c1e4a405
Unverified
Commit
c1e4a405
authored
May 24, 2025
by
qizixi
Committed by
GitHub
May 24, 2025
Browse files
[V1][Spec Decode] Support multi-layer eagle draft model (#18030)
Signed-off-by:
qizixi
<
qizixi@meta.com
>
parent
a8593205
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
9 deletions
+45
-9
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+3
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+29
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+13
-5
No files found.
tests/v1/spec_decode/test_eagle.py
View file @
c1e4a405
...
...
@@ -246,6 +246,9 @@ def test_propose(num_speculative_tokens):
# Assign the mock to the proposer
proposer
.
model
=
model_mock
# Assign draft attn_layer_names since load_model is not invoked
proposer
.
attn_layer_names
=
[
"layer.0"
]
# Create input tensors
cu_num_tokens
=
torch
.
tensor
([
0
,
seq_len_1
,
total_tokens
],
dtype
=
torch
.
int32
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
c1e4a405
...
...
@@ -12,6 +12,7 @@ from vllm.model_executor.model_loader import get_model
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
(
CommonAttentionMetadata
,
FlashAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.utils
import
prepare_eagle_input_kernel
...
...
@@ -150,6 +151,11 @@ class EagleProposer:
else
:
raise
ValueError
(
f
"Unsupported method:
{
self
.
method
}
"
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata
=
{}
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
if
self
.
use_cuda_graph
and
\
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
...
...
@@ -159,7 +165,7 @@ class EagleProposer:
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
per_layer_
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
):
ret_hidden_states
=
self
.
model
(
...
...
@@ -245,7 +251,7 @@ class EagleProposer:
self
.
hidden_states
[:
batch_size
]
=
hidden_states
# Run the model.
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
per_layer_
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch_size
):
last_hidden_states
,
hidden_states
=
self
.
model
(
...
...
@@ -318,8 +324,8 @@ class EagleProposer:
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
target_attn_layer_names
)
assert
len
(
draft_attn_layer_names
)
==
1
self
.
attn_layer_name
=
next
(
iter
(
draft_attn_layer_names
)
)
self
.
attn_layer_name
s
=
list
(
draft_attn_layer_names
)
# share embed_tokens with the target model if needed
if
get_pp_group
().
world_size
==
1
:
...
...
@@ -355,6 +361,25 @@ class EagleProposer:
self
.
hidden_states
[:
num_tokens
],
)
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Validate that all eagle layers belong to the same KVCacheGroup.
Need this assumption to ensure all eagle layers can use the
same AttentionMetadata.
May extend to multiple AttentionMetadata in the future.
"""
kv_cache_groups
:
dict
[
str
,
int
]
=
{}
for
id
,
kv_cache_group
in
enumerate
(
kv_cache_config
.
kv_cache_groups
):
for
layer_name
in
kv_cache_group
.
layer_names
:
kv_cache_groups
[
layer_name
]
=
id
assert
len
(
set
([
kv_cache_groups
[
layer_name
]
for
layer_name
in
self
.
attn_layer_names
])
)
==
1
,
"All eagle layers should belong to the same kv cache group"
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
c1e4a405
...
...
@@ -1360,11 +1360,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
.
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
async_tensor_h2d
(
next_token_ids
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_name
]
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_names
[
0
]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if
hasattr
(
eagle_attn_metadata
,
"block_table"
):
...
...
@@ -2018,6 +2020,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV cache specs.
raise
ValueError
(
"Unknown KV cache spec type."
)
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# validate all draft model layers belong to the same kv cache
# group
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
bind_kv_cache
(
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment