Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
39c0813a
Unverified
Commit
39c0813a
authored
May 01, 2025
by
qizixi
Committed by
GitHub
May 01, 2025
Browse files
[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE3 (#17504)
Signed-off-by:
qizixi
<
qizixi@meta.com
>
parent
9b70e2b4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
31 deletions
+36
-31
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+17
-8
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+19
-23
No files found.
vllm/model_executor/models/llama_eagle3.py
View file @
39c0813a
...
@@ -6,7 +6,8 @@ import torch
...
@@ -6,7 +6,8 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
...
@@ -76,17 +77,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
...
@@ -76,17 +77,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_torch_compile
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
model
_config
:
Model
Config
,
vllm
_config
:
Vllm
Config
,
start_layer_id
:
int
=
0
,
start_layer_id
:
int
=
0
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
model_config
.
hf_config
self
.
config
=
vllm_config
.
\
speculative_config
.
draft_model_config
.
hf_config
self
.
vocab_size
=
self
.
config
.
vocab_size
self
.
vocab_size
=
self
.
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
config
.
vocab_size
,
self
.
config
.
vocab_size
,
...
@@ -119,8 +122,7 @@ class LlamaModel(nn.Module):
...
@@ -119,8 +122,7 @@ class LlamaModel(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
input_embeds
=
self
.
embed_tokens
(
input_ids
)
input_embeds
=
self
.
embed_tokens
(
input_ids
)
if
(
hidden_states
.
shape
[
-
1
]
!=
input_embeds
.
shape
[
-
1
]):
assert
hidden_states
.
shape
[
-
1
]
==
input_embeds
.
shape
[
-
1
]
hidden_states
=
self
.
fc
(
hidden_states
)
residual
=
None
residual
=
None
hidden_states
,
residual
=
self
.
layers
[
0
](
hidden_states
,
residual
=
self
.
layers
[
0
](
...
@@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
...
@@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
start_layer_id
:
int
=
0
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
start_layer_id
:
int
=
0
):
nn
.
Module
.
__init__
(
self
)
nn
.
Module
.
__init__
(
self
)
model_
config
=
vllm_config
.
speculative_config
.
draft_model_config
self
.
config
=
vllm_config
.
\
self
.
config
=
model_config
.
hf_config
speculative_config
.
draft_
model_config
.
hf_config
self
.
model
=
LlamaModel
(
model
_config
=
model
_config
,
self
.
model
=
LlamaModel
(
vllm
_config
=
vllm
_config
,
start_layer_id
=
start_layer_id
,
start_layer_id
=
start_layer_id
,
prefix
=
"model"
)
prefix
=
"model"
)
...
@@ -214,6 +216,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
...
@@ -214,6 +216,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
logits_new
[:,
targets
]
=
logits
logits_new
[:,
targets
]
=
logits
return
logits_new
return
logits_new
def
combine_hidden_states
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# combine multiple auxiliary hidden states returned by eagle3
return
self
.
model
.
fc
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
loader
=
AutoWeightsLoader
(
self
,
self
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
39c0813a
...
@@ -10,6 +10,7 @@ from vllm.logger import init_logger
...
@@ -10,6 +10,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.model_loader.loader
import
get_model_loader
from
vllm.model_executor.model_loader.loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
@@ -39,9 +40,7 @@ class EagleProposer:
...
@@ -39,9 +40,7 @@ class EagleProposer:
self
.
hidden_size
=
vllm_config
.
model_config
.
get_hidden_size
()
self
.
hidden_size
=
vllm_config
.
model_config
.
get_hidden_size
()
# TODO: make eagle3 compatible with cudagraph
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
self
.
use_cuda_graph
=
self
.
method
!=
'eagle3'
and
\
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
and
==
CompilationLevel
.
PIECEWISE
and
not
self
.
vllm_config
.
model_config
.
enforce_eager
)
not
self
.
vllm_config
.
model_config
.
enforce_eager
)
...
@@ -90,6 +89,12 @@ class EagleProposer:
...
@@ -90,6 +89,12 @@ class EagleProposer:
batch_size
=
next_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
if
self
.
method
==
"eagle3"
:
assert
isinstance
(
self
.
model
,
Eagle3LlamaForCausalLM
)
target_hidden_states
=
self
.
model
.
combine_hidden_states
(
target_hidden_states
)
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
# Shift the input ids by one token.
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
...
@@ -126,12 +131,7 @@ class EagleProposer:
...
@@ -126,12 +131,7 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
positions
[:
num_tokens
]
=
target_positions
if
self
.
method
==
'eagle'
:
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
hidden_states
=
self
.
hidden_states
else
:
# TODO: make eagle3 compatible with cuda graph
hidden_states
=
target_hidden_states
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
...
@@ -139,7 +139,7 @@ class EagleProposer:
...
@@ -139,7 +139,7 @@ class EagleProposer:
last_hidden_states
,
hidden_states
=
self
.
model
(
last_hidden_states
,
hidden_states
=
self
.
model
(
input_ids
=
self
.
input_ids
[:
num_input_tokens
],
input_ids
=
self
.
input_ids
[:
num_input_tokens
],
positions
=
self
.
positions
[:
num_input_tokens
],
positions
=
self
.
positions
[:
num_input_tokens
],
hidden_states
=
hidden_states
[:
num_input_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
)
)
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
...
@@ -209,10 +209,7 @@ class EagleProposer:
...
@@ -209,10 +209,7 @@ class EagleProposer:
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
positions
[:
batch_size
]
=
clamped_positions
if
self
.
method
==
'eagle'
:
# TODO: make eagle3 compatible with cudagraph.
self
.
hidden_states
[:
batch_size
]
=
hidden_states
self
.
hidden_states
[:
batch_size
]
=
hidden_states
hidden_states
=
self
.
hidden_states
# Run the model.
# Run the model.
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
...
@@ -221,7 +218,7 @@ class EagleProposer:
...
@@ -221,7 +218,7 @@ class EagleProposer:
last_hidden_states
,
hidden_states
=
self
.
model
(
last_hidden_states
,
hidden_states
=
self
.
model
(
input_ids
=
self
.
input_ids
[:
input_batch_size
],
input_ids
=
self
.
input_ids
[:
input_batch_size
],
positions
=
self
.
positions
[:
input_batch_size
],
positions
=
self
.
positions
[:
input_batch_size
],
hidden_states
=
hidden_states
[:
input_batch_size
],
hidden_states
=
self
.
hidden_states
[:
input_batch_size
],
)
)
hidden_states
=
hidden_states
[:
batch_size
]
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
...
@@ -314,7 +311,6 @@ class EagleProposer:
...
@@ -314,7 +311,6 @@ class EagleProposer:
)
->
None
:
)
->
None
:
with
set_forward_context
(
None
,
self
.
vllm_config
,
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
num_tokens
=
num_tokens
):
if
self
.
method
==
'eagle'
:
self
.
model
(
self
.
model
(
input_ids
=
self
.
input_ids
[:
num_tokens
],
input_ids
=
self
.
input_ids
[:
num_tokens
],
positions
=
self
.
positions
[:
num_tokens
],
positions
=
self
.
positions
[:
num_tokens
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment