Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c70ac4b8
Unverified
Commit
c70ac4b8
authored
Sep 26, 2025
by
qizixi
Committed by
GitHub
Sep 26, 2025
Browse files
[spec decode] Consolidate speculative decode method name for MTP (#25232)
Signed-off-by:
zixi-qi
<
qizixi@meta.com
>
parent
cf892028
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
287 additions
and
40 deletions
+287
-40
examples/offline_inference/spec_decode.py
examples/offline_inference/spec_decode.py
+3
-2
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+65
-0
tests/v1/spec_decode/test_mtp.py
tests/v1/spec_decode/test_mtp.py
+195
-0
vllm/config/speculative.py
vllm/config/speculative.py
+19
-31
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+4
-6
No files found.
examples/offline_inference/spec_decode.py
View file @
c70ac4b8
...
...
@@ -54,6 +54,7 @@ def parse_args():
"--method"
,
type
=
str
,
default
=
"eagle"
,
choices
=
[
"ngram"
,
"eagle"
,
"eagle3"
,
"mtp"
],
)
parser
.
add_argument
(
"--num-spec-tokens"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--prompt-lookup-max"
,
type
=
int
,
default
=
5
)
...
...
@@ -118,9 +119,9 @@ def main(args):
"prompt_lookup_max"
:
args
.
prompt_lookup_max
,
"prompt_lookup_min"
:
args
.
prompt_lookup_min
,
}
elif
args
.
method
.
endswith
(
"mtp"
)
:
elif
args
.
method
==
"mtp"
:
speculative_config
=
{
"method"
:
args
.
method
,
"method"
:
"mtp"
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
}
else
:
...
...
tests/v1/e2e/test_spec_decode.py
View file @
c70ac4b8
...
...
@@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.platforms
import
current_platform
MTP_SIMILARITY_RATE
=
0.8
def
get_test_prompts
(
mm_enabled
:
bool
):
prompt_types
=
[
"repeat"
,
"sentence"
]
...
...
@@ -222,3 +224,66 @@ def test_eagle_correctness(
del
spec_llm
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
    ],
    ids=["mimo", "deepseek"])
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
):
    """
    Compare the outputs of an original LLM and a speculative LLM;
    they should be the same when using MTP speculative decoding.

    model_setup: (method, model_name, tp_size)
    mm_enabled: forwarded to get_test_prompts() to select the prompt set.
    """
    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        # The env vars must be in place before either LLM is constructed.
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_MLA_DISABLE", "1")

        method, model_name, tp_size = model_setup

        # Reference run: same model, no speculative decoding.
        ref_llm = LLM(model=model_name,
                      max_model_len=2048,
                      tensor_parallel_size=tp_size,
                      trust_remote_code=True)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        # Release the reference engine before building the speculative one.
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Speculative run: same model with one MTP draft token per step.
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: the number of exactly matching prompts must be strictly
        # greater than int(MTP_SIMILARITY_RATE * number_of_prompts).
        # Upon failure, inspect the printed outputs to check for inaccuracy.
        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs)), (
            f"only {matches} prompts matched exactly ({misses} mismatched)")

        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
tests/v1/spec_decode/test_mtp.py
0 → 100644
View file @
c70ac4b8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest
import
mock
import
pytest
import
torch
from
tests.v1.attention.utils
import
(
BatchSpec
,
_Backend
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
get_attention_backend
)
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
from
vllm.config.load
import
LoadConfig
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.v1.spec_decode.eagle
import
EagleProposer
# Model identifier used for both the target and the draft configuration below.
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"


def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
    """Build an EagleProposer configured for the "mtp" speculative method.

    The target ModelConfig and the SpeculativeConfig both point at
    ``mimo_7b_dir``; the proposer is placed on ``current_platform.device_type``.
    """
    target_config = ModelConfig(
        model=mimo_7b_dir,
        runner="generate",
        max_model_len=100,
        trust_remote_code=True,
    )

    spec_config = SpeculativeConfig(
        target_model_config=target_config,
        target_parallel_config=ParallelConfig(),
        model=mimo_7b_dir,
        method="mtp",
        num_speculative_tokens=num_speculative_tokens,
    )

    # Every sub-config other than model/speculative uses its defaults.
    engine_kwargs = {
        "model_config": target_config,
        "cache_config": CacheConfig(),
        "speculative_config": spec_config,
        "device_config": DeviceConfig(device=current_platform.device_type),
        "parallel_config": ParallelConfig(),
        "load_config": LoadConfig(),
        "scheduler_config": SchedulerConfig(),
    }
    engine_config = VllmConfig(**engine_kwargs)

    return EagleProposer(vllm_config=engine_config,
                         device=current_platform.device_type)
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
                                mock_get_pp_group):
    """Check that load_model() for MTP reuses the target's lm_head/embeddings.

    get_model, get_layers_from_vllm_config and get_pp_group are patched in
    the eagle module, so no real weights are loaded.
    """
    embed_shape = (131072, 4096)

    # Draft model handed back by the patched get_model().
    draft_model = mock.MagicMock()
    draft_model.model.embed_tokens.weight.shape = embed_shape
    mock_get_model.return_value = draft_model

    # First layer lookup yields only the target layer, the second yields
    # the target layer plus one draft layer.
    target_layers = {"target_attn_1": mock.MagicMock()}
    layers_with_draft = dict(target_layers, draft_attn_1=mock.MagicMock())
    mock_get_layers.side_effect = [target_layers, layers_with_draft]

    # Pipeline-parallel group reporting a single rank.
    mock_get_pp_group.return_value = mock.MagicMock(world_size=1)

    # Autospec'd stand-in for the target model.
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.model.embed_tokens.weight.shape = embed_shape
    target_model.lm_head = mock.MagicMock()

    # Build the proposer and load the (mocked) draft model.
    mtp_proposer = _create_mtp_proposer(num_speculative_tokens=4)
    mtp_proposer.load_model(target_model)

    # The draft model must have been loaded exactly once.
    mock_get_model.assert_called_once()

    # lm_head is shared with the target model.
    assert mtp_proposer.model.lm_head == target_model.lm_head

    # embed_tokens is shared with the target model.
    assert (mtp_proposer.model.model.embed_tokens ==
            target_model.model.embed_tokens)
@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
    """Run propose() against a mocked MTP model and check the output shape.

    The mocked model returns hidden-state tensors directly from its forward
    call, and compute_logits() returns logits with one dominant token.
    """
    device = torch.device(current_platform.device_type)
    seq_lens = [5, 3]
    batch_size = 2
    vocab_size = 100
    total_tokens = sum(seq_lens)

    proposer = _create_mtp_proposer(num_speculative_tokens)
    hidden_size = proposer.hidden_size

    def _peaked_logits(rows, cols, hot_index):
        # -100 everywhere except column `hot_index`, which is set to +100.
        out = torch.full((rows, cols), -100.0, device=device)
        out[:, hot_index] = 100.0
        return out

    # Stand-in for the MTP model: calling it yields hidden states directly.
    model_mock = mock.MagicMock()
    if num_speculative_tokens == 1:
        model_mock.return_value = torch.zeros(total_tokens,
                                              hidden_size,
                                              device=device)
        model_mock.compute_logits.return_value = _peaked_logits(
            batch_size, vocab_size, 42)
    else:
        # One forward per speculative token: the first covers every input
        # token, later ones cover one row per request.
        model_mock.side_effect = [
            torch.zeros(total_tokens if step == 0 else batch_size,
                        hidden_size,
                        device=device)
            for step in range(num_speculative_tokens)
        ]
        model_mock.compute_logits.side_effect = [
            _peaked_logits(batch_size, vocab_size, 42 + step)
            for step in range(num_speculative_tokens)
        ]

    proposer.model = model_mock
    proposer.attn_layer_names = ["layer.0"]

    # Inputs for propose().
    common_attn_metadata = create_common_attn_metadata(
        BatchSpec(seq_lens=seq_lens, query_lens=seq_lens),
        block_size=16,
        device=device,
    )
    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    # Positions restart at 0 for each request.
    target_positions = torch.cat(
        [torch.arange(length, device=device) for length in seq_lens])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    sampling_metadata = mock.MagicMock()

    # Attention metadata builder taken from the FLASH_ATTN backend.
    builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
    proposer.runner = mock.MagicMock()
    proposer.attn_metadata_builder = builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    drafted = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
                               target_hidden_states=target_hidden_states,
                               next_token_ids=next_token_ids,
                               last_token_indices=None,
                               common_attn_metadata=common_attn_metadata,
                               sampling_metadata=sampling_metadata)

    # The mocked model must have been invoked at least once.
    assert model_mock.called

    # One row per request, one column per speculative token.
    assert drafted.shape == (batch_size, num_speculative_tokens)
vllm/config/speculative.py
View file @
c70ac4b8
...
...
@@ -32,7 +32,9 @@ logger = init_logger(__name__)
SpeculativeMethod
=
Literal
[
"ngram"
,
"eagle"
,
"eagle3"
,
"medusa"
,
"mlp_speculator"
,
"draft_model"
,
"deepseek_mtp"
,
"ernie_mtp"
,
"qwen3_next_mtp"
,
"mimo_mtp"
,
"longcat_flash_mtp"
]
"longcat_flash_mtp"
,
"mtp"
]
MTP_MODEL_TYPES
=
(
"deepseek_mtp"
,
"mimo_mtp"
,
"glm4_moe_mtp"
,
"ernie_mtp"
,
"qwen3_next_mtp"
,
"longcat_flash_mtp"
)
@
config
...
...
@@ -207,11 +209,16 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.
if
self
.
method
in
MTP_MODEL_TYPES
:
logger
.
warning
(
"method `%s` is deprecated and replaced with mtp."
,
self
.
method
)
self
.
method
=
"mtp"
if
self
.
model
is
None
and
self
.
num_speculative_tokens
is
not
None
:
# TODO(Shangming): Refactor mtp configuration logic when supporting
if
(
self
.
target_model_config
and
self
.
target_model_config
.
hf_text_config
.
model_type
i
n
(
"deepseek_v3"
,
"mimo"
,
"ernie4_5_moe"
,
"qwen3_next"
)):
if
self
.
method
==
"mtp"
:
assert
(
self
.
target_model_config
i
s
not
None
),
"target_model_config must be present for mtp"
# use the draft model from the same model:
self
.
model
=
self
.
target_model_config
.
model
# Align the quantization of draft model for cases such as
...
...
@@ -312,31 +319,13 @@ class SpeculativeConfig:
"mlp_speculator"
):
self
.
method
=
"mlp_speculator"
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
in
(
"deepseek_mtp"
,
"mimo_mtp"
,
"glm4_moe_mtp"
)
):
self
.
method
=
"
deepseek_
mtp"
in
MTP_MODEL_TYPES
):
self
.
method
=
"mtp"
if
self
.
num_speculative_tokens
>
1
:
logger
.
warning
(
"All Deepseek MTP models only have "
\
"one layer. Might need some code changes "
\
"to support multiple layers."
)
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
"ernie_mtp"
):
self
.
method
=
"ernie_mtp"
if
self
.
num_speculative_tokens
>
1
:
logger
.
warning
(
"All Ernie MTP models only have "
\
"one layer. Might need some code changes "
\
"to support multiple layers."
)
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
"qwen3_next_mtp"
):
self
.
method
=
"qwen3_next_mtp"
if
self
.
num_speculative_tokens
>
1
:
logger
.
warning
(
"All Qwen3Next MTP models only have "
\
"one layer. Might need some code changes "
\
"to support multiple layers."
"Enabling num_speculative_tokens > 1 will run"
\
"multiple times of forward on same MTP layer"
\
",which may result in lower acceptance rate"
\
)
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
in
(
"longcat_flash_mtp"
)):
...
...
@@ -353,7 +342,7 @@ class SpeculativeConfig:
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or
deepseek_
mtp."
)
"eagle, or mtp."
)
# Replace hf_config for EAGLE draft_model
if
self
.
method
in
(
"eagle"
,
"eagle3"
):
...
...
@@ -562,8 +551,7 @@ class SpeculativeConfig:
return
self
.
num_speculative_tokens
def
use_eagle
(
self
)
->
bool
:
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"deepseek_mtp"
,
"ernie_mtp"
,
"qwen3_next_mtp"
,
"longcat_flash_mtp"
)
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
)
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
...
...
vllm/engine/arg_utils.py
View file @
c70ac4b8
...
...
@@ -1481,7 +1481,7 @@ class EngineArgs:
raise
NotImplementedError
(
"Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or
deepseek_
mtp."
)
"such as ngram, medusa, eagle, or mtp."
)
V1_BACKENDS
=
[
"FLASH_ATTN"
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
c70ac4b8
...
...
@@ -222,8 +222,7 @@ class EagleProposer:
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
inputs_embeds
=
inputs_embeds
,
)
if
self
.
method
in
(
"deepseek_mtp"
,
"ernie_mtp"
,
"qwen3_next_mtp"
,
"longcat_flash_mtp"
):
if
self
.
method
==
"mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
last_hidden_states
else
:
...
...
@@ -352,8 +351,7 @@ class EagleProposer:
hidden_states
=
self
.
hidden_states
[:
input_batch_size
],
inputs_embeds
=
inputs_embeds
,
)
if
self
.
method
in
(
"deepseek_mtp"
,
"ernie_mtp"
,
"qwen3_next_mtp"
,
"longcat_flash_mtp"
):
if
self
.
method
==
"mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
ret_hidden_states
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment