Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
83e69a09
Unverified
Commit
83e69a09
authored
Aug 20, 2025
by
Xin Yang
Committed by
GitHub
Aug 20, 2025
Browse files
[Model] Support deepseek with eagle (#21086)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
3aa8c100
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
255 additions
and
1 deletion
+255
-1
tests/models/registry.py
tests/models/registry.py
+3
-0
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+5
-1
vllm/model_executor/models/deepseek_eagle.py
vllm/model_executor/models/deepseek_eagle.py
+246
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
tests/models/registry.py
View file @
83e69a09
...
...
@@ -530,6 +530,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"DeepSeekMTPModel"
:
_HfExamplesInfo
(
"luccafong/deepseek_mtp_main_random"
,
speculative_model
=
"luccafong/deepseek_mtp_draft_random"
,
# noqa: E501
trust_remote_code
=
True
),
"EagleDeepSeekMTPModel"
:
_HfExamplesInfo
(
"eagle618/deepseek-v3-random"
,
speculative_model
=
"eagle618/eagle-deepseek-v3-random"
,
# noqa: E501
trust_remote_code
=
True
),
"EagleLlamaForCausalLM"
:
_HfExamplesInfo
(
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
trust_remote_code
=
True
,
speculative_model
=
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
...
...
tests/v1/e2e/test_spec_decode.py
View file @
83e69a09
...
...
@@ -144,6 +144,8 @@ def test_ngram_correctness(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
True
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
((
"eagle"
,
"eagle618/deepseek-v3-random"
,
"eagle618/eagle-deepseek-v3-random"
,
1
),
False
),
],
ids
=
[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
...
...
@@ -151,7 +153,8 @@ def test_ngram_correctness(
"llama3_eagle"
,
"llama3_eagle3"
,
"llama4_eagle"
,
"llama4_eagle_mm"
"llama4_eagle_mm"
,
"deepseek_eagle"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
...
...
@@ -177,6 +180,7 @@ def test_eagle_correctness(
'''
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
(
attn_backend
==
"TRITON_ATTN_VLLM_V1"
...
...
vllm/model_executor/models/deepseek_eagle.py
0 → 100644
View file @
83e69a09
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.deepseek_v2
import
(
DeepseekV2DecoderLayer
,
DeepseekV3ForCausalLM
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Decoder stack used as the body of the DeepSeek EAGLE draft model.

    The forward pass embeds the input token ids, RMS-normalizes both the
    embeddings and the hidden states handed in by the caller, concatenates
    the two along the last dimension, projects the result through ``fc``
    and then runs it through a list of ``DeepseekV2DecoderLayer`` modules
    followed by a final RMSNorm.

    All hyper-parameters are read from the *draft* model's HF config
    (``vllm_config.speculative_config.draft_model_config.hf_config``).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Build the embedding, decoder layers, fusion projection and norms.

        Args:
            vllm_config: Engine configuration. The draft HF config, the
                model config, the cache config and the quantization config
                are all taken from it.
            prefix: Name prefix prepended (via ``maybe_prefix``) to the
                ``embed_tokens`` and ``layers.N`` sub-module prefixes.
            start_layer_id: Offset added to the layer index when building
                each decoder layer's prefix (``layers.{i + start_layer_id}``).
        """
        super().__init__()
        # Hyper-parameters come from the draft model's HF config.
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        # NOTE(review): this is vllm_config.model_config, not the draft
        # model config read above -- confirm this is the intended config
        # to hand to the decoder layers.
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # Layer prefixes are numbered starting at `start_layer_id` rather
        # than 0. EagleDeepseekV3ForCausalLM passes the target model's layer
        # count here; presumably this keeps draft layer names distinct from
        # the target model's layer names -- TODO confirm.
        self.layers = nn.ModuleList([
            DeepseekV2DecoderLayer(
                self.config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ) for i in range(self.config.num_hidden_layers)
        ])

        # Projects the concatenation [embeddings, hidden_states] (hence the
        # `* 2` input width) back down to a single hidden-size vector.
        # NOTE(review): the sizes are read from `self.config.model`, while
        # every other module here reads `self.config` directly -- confirm
        # both resolve to the same hidden size for the draft config.
        self.fc = nn.Linear(
            self.config.model.hidden_size * 2,
            self.config.model.hidden_size,
            bias=False,
        )

        # enorm: applied to token embeddings; hnorm: applied to the incoming
        # hidden states; norm: applied after the last decoder layer.
        self.enorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.hnorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.norm = RMSNorm(self.config.hidden_size,
                            eps=self.config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft decoder stack.

        Args:
            input_ids: Token ids looked up in ``embed_tokens``.
            positions: Passed unchanged to every decoder layer.
            hidden_states: Hidden states supplied by the caller; they are
                normalized with ``hnorm`` and concatenated with the
                normalized token embeddings along the last dimension.

        Returns:
            A 2-tuple whose two elements are the same tensor: the output of
            the final RMSNorm.
        """
        input_embeds = self.embed_tokens(input_ids)

        # Fuse token embeddings with the incoming hidden states.
        inputs = torch.cat(
            [self.enorm(input_embeds),
             self.hnorm(hidden_states)], dim=-1)
        hidden_states = self.fc(inputs)
        # The first layer receives no residual; each layer returns the
        # residual that is fed to the next one.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        # The norm is called with the residual and returns a pair; only the
        # first element is kept.
        hidden_states, _ = self.norm(hidden_states, residual)
        # The same tensor is returned twice. Presumably callers expect a
        # pair of outputs -- TODO confirm against the caller.
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is routed through three stages, in
        order, and handled by the first stage that accepts it:

        1. ``stacked_params_mapping``: checkpoint shards that are merged
           into one fused parameter (e.g. ``gate_proj``/``up_proj`` into
           ``gate_up_proj``).
        2. ``expert_params_mapping``: per-expert MoE weights.
        3. Fallback: a plain parameter lookup by (possibly remapped) name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (after any renaming done here) for
            which a weight loader was invoked. Entries skipped by an outer
            ``continue`` are not included.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoint entries for rotary inverse frequencies are ignored.
            if "rotary_emb.inv_freq" in name:
                continue

            # Stage 1: stacked / fused parameters. A `break` means the
            # weight was loaded here; if the loop finishes without one, the
            # `else` clause below moves on to stage 2.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                # if go with fusion option, then update name
                if ((param_name == "fused_qkv_a_proj")
                        and name_mapped not in params_dict):
                    continue
                else:
                    name = name_mapped

                # Skip loading extra bias for GPTQ models.
                # Note: this `continue` targets the inner mapping loop, and
                # `name` has already been rewritten at this point.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Stage 2: per-expert MoE weights. Same break/else protocol
                # as above; falling through leads to stage 3.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Stage 3: ordinary parameters. The `continue`
                    # statements in this clause skip to the next checkpoint
                    # entry of the OUTER loop, so the name is not recorded
                    # in `loaded_params`.

                    # if PP disabled then draft will share embed with target
                    if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    # Parameters without a custom loader use the default one.
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            # Reached only when one of the three stages invoked a loader.
            loaded_params.add(name)
        return loaded_params
class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
    """EAGLE draft model for DeepSeek: a ``DeepseekV2Model`` body plus an
    LM head and logits processor.

    The class derives from ``DeepseekV3ForCausalLM`` but does not run that
    class's constructor; every sub-module is created here from the draft
    model's HF config.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the draft body, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the draft HF config,
                quantization config, model config and parallel config are
                read from it.
            prefix: Accepted for interface compatibility; the body is always
                created with the literal prefix ``"model"``.
        """
        # Initialize only nn.Module state; the parent class's __init__ is
        # intentionally not invoked.
        nn.Module.__init__(self)

        speculative_config = vllm_config.speculative_config
        self.config = speculative_config.draft_model_config.hf_config

        # Number of layers reported for the target model; used as the
        # starting index for the draft body's layer prefixes.
        num_target_layers = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)

        self.model = DeepseekV2Model(
            vllm_config=vllm_config,
            prefix="model",
            start_layer_id=num_target_layers,
        )

        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=vllm_config.quant_config,
        )

        # The scale defaults to 1.0 when the config has no `logit_scale`.
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size,
            scale=getattr(self.config, "logit_scale", 1.0),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft body.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is provided.
        """
        if inputs_embeds is None:
            return self.model(input_ids, positions, hidden_states)

        raise NotImplementedError(
            f"{type(self).__name__} does not support multimodal inputs yet.")

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits processor's output for ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Every checkpoint name that does not contain ``"lm_head"`` is
        prefixed with ``"model."`` before loading. Nothing is returned.
        """
        auto_loader = AutoWeightsLoader(self, skip_prefixes=None)

        # All entries are collected before loading; a later entry with the
        # same final name replaces an earlier one.
        renamed_weights = {
            (key if "lm_head" in key else "model." + key): tensor
            for key, tensor in weights
        }
        auto_loader.load_weights(renamed_weights.items())
vllm/model_executor/models/registry.py
View file @
83e69a09
...
...
@@ -264,6 +264,7 @@ _SPECULATIVE_DECODING_MODELS = {
"Eagle3LlamaForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleDeepSeekMTPModel"
:
(
"deepseek_eagle"
,
"EagleDeepseekV3ForCausalLM"
),
"DeepSeekMTPModel"
:
(
"deepseek_mtp"
,
"DeepSeekMTP"
),
"Glm4MoeMTPModel"
:
(
"glm4_moe_mtp"
,
"Glm4MoeMTP"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment