Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
70788bdb
Unverified
Commit
70788bdb
authored
Apr 29, 2025
by
Bryan Lu
Committed by
GitHub
Apr 29, 2025
Browse files
[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)
Signed-off-by:
Bryan Lu
<
yuzhelu@amazon.com
>
parent
c9c1b59e
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
152 additions
and
53 deletions
+152
-53
examples/offline_inference/eagle.py
examples/offline_inference/eagle.py
+12
-2
vllm/compilation/backends.py
vllm/compilation/backends.py
+12
-3
vllm/model_executor/models/llama_eagle.py
vllm/model_executor/models/llama_eagle.py
+14
-11
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+3
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+100
-22
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-13
No files found.
examples/offline_inference/eagle.py
View file @
70788bdb
...
@@ -36,6 +36,10 @@ def parse_args():
...
@@ -36,6 +36,10 @@ def parse_args():
help
=
"downloaded from the eagle repo "
\
help
=
"downloaded from the eagle repo "
\
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
)
parser
.
add_argument
(
"--method"
,
type
=
str
,
default
=
'eagle'
,
choices
=
[
'eagle'
,
'eagle3'
])
parser
.
add_argument
(
"--max_num_seqs"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--max_num_seqs"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--num_prompts"
,
type
=
int
,
default
=
80
)
parser
.
add_argument
(
"--num_prompts"
,
type
=
int
,
default
=
80
)
parser
.
add_argument
(
"--num_spec_tokens"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--num_spec_tokens"
,
type
=
int
,
default
=
2
)
...
@@ -53,7 +57,13 @@ def main():
...
@@ -53,7 +57,13 @@ def main():
args
=
parse_args
()
args
=
parse_args
()
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
if
args
.
method
==
'eagle'
:
eagle_dir
=
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
elif
args
.
method
==
'eagle3'
:
eagle_dir
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
eagle_dir
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
else
:
raise
ValueError
(
f
"unknown method:
{
args
.
method
}
"
)
max_model_len
=
2048
max_model_len
=
2048
...
@@ -81,7 +91,7 @@ def main():
...
@@ -81,7 +91,7 @@ def main():
max_num_seqs
=
args
.
max_num_seqs
,
max_num_seqs
=
args
.
max_num_seqs
,
gpu_memory_utilization
=
0.8
,
gpu_memory_utilization
=
0.8
,
speculative_config
=
{
speculative_config
=
{
"method"
:
"eagle3"
if
"eagle3"
in
eagle_dir
.
lower
()
else
"eagle"
,
"method"
:
args
.
method
,
"model"
:
eagle_dir
,
"model"
:
eagle_dir
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"draft_tensor_parallel_size"
:
args
.
draft_tp
,
"draft_tensor_parallel_size"
:
args
.
draft_tp
,
...
...
vllm/compilation/backends.py
View file @
70788bdb
...
@@ -347,6 +347,10 @@ class VllmBackend:
...
@@ -347,6 +347,10 @@ class VllmBackend:
PASS_KEY
=
"post_grad_custom_post_pass"
PASS_KEY
=
"post_grad_custom_post_pass"
if
PASS_KEY
in
inductor_config
:
if
PASS_KEY
in
inductor_config
:
# Config should automatically wrap all inductor passes
# Config should automatically wrap all inductor passes
if
isinstance
(
inductor_config
[
PASS_KEY
],
PostGradPassManager
):
assert
(
inductor_config
[
PASS_KEY
].
uuid
()
==
self
.
post_grad_pass_manager
.
uuid
())
else
:
assert
isinstance
(
inductor_config
[
PASS_KEY
],
InductorPass
)
assert
isinstance
(
inductor_config
[
PASS_KEY
],
InductorPass
)
self
.
post_grad_pass_manager
.
add
(
inductor_config
[
PASS_KEY
])
self
.
post_grad_pass_manager
.
add
(
inductor_config
[
PASS_KEY
])
inductor_config
[
PASS_KEY
]
=
self
.
post_grad_pass_manager
inductor_config
[
PASS_KEY
]
=
self
.
post_grad_pass_manager
...
@@ -408,8 +412,13 @@ class VllmBackend:
...
@@ -408,8 +412,13 @@ class VllmBackend:
)
)
self
.
compilation_config
.
cache_dir
=
cache_dir
self
.
compilation_config
.
cache_dir
=
cache_dir
if
compilation_counter
.
num_graphs_seen
>
0
:
cache_dir
=
self
.
compilation_config
.
cache_dir
+
\
f
'-
{
compilation_counter
.
num_graphs_seen
}
'
else
:
cache_dir
=
self
.
compilation_config
.
cache_dir
cache_dir
=
self
.
compilation_config
.
cache_dir
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
self
.
compilation_config
.
cache_dir
=
cache_dir
rank
=
vllm_config
.
parallel_config
.
rank
rank
=
vllm_config
.
parallel_config
.
rank
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
local_cache_dir
=
os
.
path
.
join
(
cache_dir
,
f
"rank_
{
rank
}
_
{
dp_rank
}
"
)
local_cache_dir
=
os
.
path
.
join
(
cache_dir
,
f
"rank_
{
rank
}
_
{
dp_rank
}
"
)
...
...
vllm/model_executor/models/llama_eagle.py
View file @
70788bdb
...
@@ -6,7 +6,8 @@ import torch
...
@@ -6,7 +6,8 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.config
import
ModelConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -37,17 +38,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
...
@@ -37,17 +38,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
self
.
input_layernorm
=
nn
.
Identity
()
self
.
input_layernorm
=
nn
.
Identity
()
@
support_torch_compile
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
model_config
:
ModelConfig
,
vllm_config
:
VllmConfig
,
start_layer_id
:
int
=
0
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
start_layer_id
:
int
=
0
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
model_config
.
hf_config
self
.
config
=
vllm_config
.
\
speculative_config
.
draft_model_config
.
hf_config
self
.
vocab_size
=
self
.
config
.
vocab_size
self
.
vocab_size
=
self
.
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
config
.
vocab_size
,
self
.
config
.
vocab_size
,
...
@@ -75,8 +78,7 @@ class LlamaModel(nn.Module):
...
@@ -75,8 +78,7 @@ class LlamaModel(nn.Module):
hidden_states
=
self
.
fc
(
hidden_states
=
self
.
fc
(
torch
.
cat
((
input_embeds
,
hidden_states
),
dim
=-
1
))
torch
.
cat
((
input_embeds
,
hidden_states
),
dim
=-
1
))
residual
=
None
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
for
layer
in
self
.
layers
:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
...
@@ -117,12 +119,13 @@ class LlamaModel(nn.Module):
...
@@ -117,12 +119,13 @@ class LlamaModel(nn.Module):
class
EagleLlamaForCausalLM
(
LlamaForCausalLM
):
class
EagleLlamaForCausalLM
(
LlamaForCausalLM
):
def
__init__
(
self
,
*
,
model
_config
:
Model
Config
,
start_layer_id
:
int
=
0
):
def
__init__
(
self
,
*
,
vllm
_config
:
Vllm
Config
,
start_layer_id
:
int
=
0
):
nn
.
Module
.
__init__
(
self
)
nn
.
Module
.
__init__
(
self
)
self
.
config
=
model_config
.
hf_config
self
.
config
=
vllm_config
.
\
self
.
model
=
LlamaModel
(
model_config
=
model_config
,
speculative_config
.
draft_model_config
.
hf_config
start_layer_id
=
start_layer_id
,
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
"model"
)
prefix
=
"model"
,
start_layer_id
=
start_layer_id
)
logit_scale
=
getattr
(
self
.
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
self
.
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
config
.
vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
config
.
vocab_size
,
...
...
vllm/model_executor/models/llama_eagle3.py
View file @
70788bdb
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
...
@@ -167,8 +167,9 @@ class LlamaModel(nn.Module):
...
@@ -167,8 +167,9 @@ class LlamaModel(nn.Module):
class
Eagle3LlamaForCausalLM
(
LlamaForCausalLM
):
class
Eagle3LlamaForCausalLM
(
LlamaForCausalLM
):
def
__init__
(
self
,
*
,
model
_config
:
Model
Config
,
start_layer_id
:
int
=
0
):
def
__init__
(
self
,
*
,
vllm
_config
:
Vllm
Config
,
start_layer_id
:
int
=
0
):
nn
.
Module
.
__init__
(
self
)
nn
.
Module
.
__init__
(
self
)
model_config
=
vllm_config
.
speculative_config
.
draft_model_config
self
.
config
=
model_config
.
hf_config
self
.
config
=
model_config
.
hf_config
self
.
model
=
LlamaModel
(
model_config
=
model_config
,
self
.
model
=
LlamaModel
(
model_config
=
model_config
,
start_layer_id
=
start_layer_id
,
start_layer_id
=
start_layer_id
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
70788bdb
...
@@ -4,7 +4,7 @@ import torch.nn as nn
...
@@ -4,7 +4,7 @@ import torch.nn as nn
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
CompilationLevel
,
VllmConfig
,
set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.loader
import
get_model_loader
from
vllm.model_executor.model_loader.loader
import
get_model_loader
...
@@ -26,10 +26,41 @@ class EagleProposer:
...
@@ -26,10 +26,41 @@ class EagleProposer:
device
:
torch
.
device
,
device
:
torch
.
device
,
):
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
method
=
self
.
vllm_config
.
speculative_config
.
method
self
.
num_speculative_tokens
=
(
self
.
num_speculative_tokens
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
)
vllm_config
.
speculative_config
.
num_speculative_tokens
)
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
\
.
max_num_batched_tokens
self
.
hidden_size
=
vllm_config
.
model_config
.
get_hidden_size
()
# TODO: make eagle3 compatible with cudagraph
self
.
use_cuda_graph
=
self
.
method
!=
'eagle3'
and
\
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
and
not
self
.
vllm_config
.
model_config
.
enforce_eager
)
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
# persistent buffers for cuda graph
self
.
input_ids
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
hidden_states
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
)
# We need +1 here because the arange is used to set query_start_loc,
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
...
@@ -59,13 +90,12 @@ class EagleProposer:
...
@@ -59,13 +90,12 @@ class EagleProposer:
batch_size
=
next_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
input_ids
=
torch
.
empty_like
(
target_token_ids
)
# Shift the input ids by one token.
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids
[:
-
1
]
=
target_token_ids
[
1
:]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
# Replace the last token with the next token.
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids
[
last_token_indices
]
=
next_token_ids
self
.
input_ids
[
last_token_indices
]
=
next_token_ids
# FA requires seq_len to have dtype int32.
# FA requires seq_len to have dtype int32.
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
...
@@ -88,14 +118,30 @@ class EagleProposer:
...
@@ -88,14 +118,30 @@ class EagleProposer:
prefix_kv_lens
=
None
,
prefix_kv_lens
=
None
,
suffix_kv_lens
=
None
,
suffix_kv_lens
=
None
,
)
)
if
self
.
use_cuda_graph
and
\
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
else
:
num_input_tokens
=
num_tokens
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
if
self
.
method
==
'eagle'
:
hidden_states_logits
,
hidden_states_fwd
=
self
.
model
(
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
input_ids
=
input_ids
,
hidden_states
=
self
.
hidden_states
hidden_states
=
target_hidden_states
,
else
:
positions
=
target_positions
,
# TODO: make eagle3 compatible with cuda graph
hidden_states
=
target_hidden_states
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
):
last_hidden_states
,
hidden_states
=
self
.
model
(
input_ids
=
self
.
input_ids
[:
num_input_tokens
],
positions
=
self
.
positions
[:
num_input_tokens
],
hidden_states
=
hidden_states
[:
num_input_tokens
],
)
)
sample_hidden_states
=
hidden_states
_logits
[
last_token_indices
]
sample_hidden_states
=
last_
hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
...
@@ -108,13 +154,20 @@ class EagleProposer:
...
@@ -108,13 +154,20 @@ class EagleProposer:
draft_token_ids_list
=
[
draft_token_ids
]
draft_token_ids_list
=
[
draft_token_ids
]
positions
=
target_positions
[
last_token_indices
]
positions
=
target_positions
[
last_token_indices
]
hidden_states
=
hidden_states_fwd
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
use_cuda_graph
and
\
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
input_batch_size
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# Update the inputs.
input_ids
=
draft_token_ids_list
[
-
1
]
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids
=
draft_token_ids_list
[
-
1
].
int
()
positions
+=
1
positions
+=
1
# NOTE(woosuk): We should handle the case where the draft model
# NOTE(woosuk): We should handle the case where the draft model
...
@@ -152,14 +205,27 @@ class EagleProposer:
...
@@ -152,14 +205,27 @@ class EagleProposer:
attn_metadata
.
slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
attn_metadata
.
slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
PADDING_SLOT_ID
)
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
positions
[:
batch_size
]
=
clamped_positions
if
self
.
method
==
'eagle'
:
# TODO: make eagle3 compatible with cudagraph.
self
.
hidden_states
[:
batch_size
]
=
hidden_states
hidden_states
=
self
.
hidden_states
# Run the model.
# Run the model.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
hidden_states_logits
,
hidden_states
=
self
.
model
(
self
.
vllm_config
,
input_ids
=
input_ids
,
num_tokens
=
input_batch_size
):
hidden_states
=
hidden_states
,
last_hidden_states
,
hidden_states
=
self
.
model
(
positions
=
clamped_positions
,
input_ids
=
self
.
input_ids
[:
input_batch_size
],
positions
=
self
.
positions
[:
input_batch_size
],
hidden_states
=
hidden_states
[:
input_batch_size
],
)
)
logits
=
self
.
model
.
compute_logits
(
hidden_states_logits
,
None
)
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_token_ids_list
.
append
(
draft_token_ids
)
...
@@ -227,13 +293,11 @@ class EagleProposer:
...
@@ -227,13 +293,11 @@ class EagleProposer:
draft_model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
draft_model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
draft_model_config
.
architectures
)
draft_model_config
.
architectures
)
self
.
model
=
draft_model_cls
(
self
.
model
=
draft_model_cls
(
model
_config
=
draft_model
_config
,
vllm
_config
=
self
.
vllm
_config
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
start_layer_id
=
target_layer_num
).
to
(
target_device
)
loaded_weights
=
self
.
model
.
load_weights
(
loaded_weights
=
self
.
model
.
load_weights
(
loader
.
get_all_weights
(
loader
.
get_all_weights
(
draft_model_config
,
self
.
model
))
self
.
vllm_config
.
speculative_config
.
draft_model_config
,
self
.
model
))
if
self
.
vllm_config
.
speculative_config
.
method
==
"eagle3"
:
if
self
.
vllm_config
.
speculative_config
.
method
==
"eagle3"
:
if
"model.embed_tokens.weight"
not
in
loaded_weights
:
if
"model.embed_tokens.weight"
not
in
loaded_weights
:
logger
.
info
(
logger
.
info
(
...
@@ -243,6 +307,20 @@ class EagleProposer:
...
@@ -243,6 +307,20 @@ class EagleProposer:
logger
.
info
(
"Loading EAGLE LM head weights from the target model."
)
logger
.
info
(
"Loading EAGLE LM head weights from the target model."
)
self
.
model
.
lm_head
=
target_model
.
lm_head
self
.
model
.
lm_head
=
target_model
.
lm_head
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
num_tokens
:
int
,
)
->
None
:
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
if
self
.
method
==
'eagle'
:
self
.
model
(
input_ids
=
self
.
input_ids
[:
num_tokens
],
positions
=
self
.
positions
[:
num_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_tokens
],
)
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# to sample the draft tokens. We will use this after we find a way to manage
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
70788bdb
...
@@ -1106,7 +1106,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1106,7 +1106,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# For mid-pipeline stages, return the hidden states.
# For mid-pipeline stages, return the hidden states.
return
hidden_states
return
hidden_states
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
sample_hidden_states
=
hidden_states
[
logits_indices
]
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
...
@@ -1172,7 +1171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1172,7 +1171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Compute prompt logprobs if needed.
# Compute prompt logprobs if needed.
prompt_logprobs_dict
=
self
.
_get_prompt_logprobs_dict
(
prompt_logprobs_dict
=
self
.
_get_prompt_logprobs_dict
(
hidden_states
,
hidden_states
[:
num_scheduled_tokens
]
,
scheduler_output
,
scheduler_output
,
)
)
...
@@ -1222,15 +1221,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1222,15 +1221,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_positions
=
positions
[:
num_scheduled_tokens
]
target_positions
=
positions
[:
num_scheduled_tokens
]
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
[
target_hidden_states
=
torch
.
cat
(
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
],
]
dim
=-
1
)
else
:
else
:
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
target_slot_mapping
=
attn_metadata
.
slot_mapping
...
@@ -1254,15 +1250,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1254,15 +1250,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_positions
=
positions
[
token_indices
]
target_positions
=
positions
[
token_indices
]
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
[
target_hidden_states
=
torch
.
cat
(
h
[
token_indices
]
for
h
in
aux_hidden_states
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
]
else
:
else
:
target_hidden_states
=
hidden_states
[
token_indices
]
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
[
token_indices
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
[
token_indices
]
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
target_hidden_states
,
dim
=-
1
)
draft_token_ids
=
self
.
drafter
.
propose
(
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_positions
=
target_positions
,
...
@@ -1506,6 +1499,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1506,6 +1499,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
else
:
hidden_states
=
outputs
hidden_states
=
outputs
if
self
.
use_spec_decode
and
\
self
.
speculative_config
.
method
in
(
'eagle'
,
'eagle3'
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
drafter
.
dummy_run
(
num_tokens
)
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
return
hidden_states
[
logit_indices
]
return
hidden_states
[
logit_indices
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment