Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
726d8972
Unverified
Commit
726d8972
authored
Jan 29, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 30, 2026
Browse files
[CI] Enable mypy import following for `vllm/spec_decode` (#33282)
Signed-off-by:
Lucas Kabela
<lucaskabela@meta.com>
parent
d334dd26
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing 5 changed files with 38 additions and 20 deletions
+38
-20
tools/pre_commit/mypy.py
tools/pre_commit/mypy.py
+0
-1
vllm/v1/spec_decode/draft_model.py
vllm/v1/spec_decode/draft_model.py
+7
-6
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+24
-9
vllm/v1/spec_decode/medusa.py
vllm/v1/spec_decode/medusa.py
+6
-4
vllm/v1/spec_decode/suffix_decoding.py
vllm/v1/spec_decode/suffix_decoding.py
+1
-0
No files found.
tools/pre_commit/mypy.py
View file @
726d8972
...
@@ -32,7 +32,6 @@ SEPARATE_GROUPS = [
...
@@ -32,7 +32,6 @@ SEPARATE_GROUPS = [
"vllm/model_executor"
,
"vllm/model_executor"
,
# v1 related
# v1 related
"vllm/v1/kv_offload"
,
"vllm/v1/kv_offload"
,
"vllm/v1/spec_decode"
,
]
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
# TODO(woosuk): Include the code from Megatron and HuggingFace.
...
...
vllm/v1/spec_decode/draft_model.py
View file @
726d8972
...
@@ -5,7 +5,6 @@ from typing import Any
...
@@ -5,7 +5,6 @@ from typing import Any
import
torch
import
torch
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config.speculative
import
SpeculativeConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
...
@@ -56,7 +55,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
...
@@ -56,7 +55,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
)
)
def
_raise_if_padded_drafter_batch_disabled
(
self
):
def
_raise_if_padded_drafter_batch_disabled
(
self
):
if
self
.
vllm_config
.
speculative_config
.
disable_padded_drafter_batch
:
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Speculative Decoding with draft models only supports "
"Speculative Decoding with draft models only supports "
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
...
@@ -64,7 +63,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
...
@@ -64,7 +63,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
)
)
def
_raise_if_vocab_size_mismatch
(
self
):
def
_raise_if_vocab_size_mismatch
(
self
):
self
.
vllm_config
.
speculative_config
.
verify_equal_vocab_size_if_draft_model
()
self
.
speculative_config
.
verify_equal_vocab_size_if_draft_model
()
def
_raise_if_draft_tp_mismatch
(
self
):
def
_raise_if_draft_tp_mismatch
(
self
):
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
...
@@ -73,7 +72,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
...
@@ -73,7 +72,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
# (because TP=1), then the torch compile cache is overwritten and corrupted.
# (because TP=1), then the torch compile cache is overwritten and corrupted.
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
# To prevent this error, we assert that both TP sizes must be the same.
# To prevent this error, we assert that both TP sizes must be the same.
spec_cfg
:
SpeculativeConfig
=
self
.
vllm_config
.
speculative_config
spec_cfg
=
self
.
speculative_config
tgt_tp
=
spec_cfg
.
target_parallel_config
.
tensor_parallel_size
tgt_tp
=
spec_cfg
.
target_parallel_config
.
tensor_parallel_size
draft_tp
=
spec_cfg
.
draft_parallel_config
.
tensor_parallel_size
draft_tp
=
spec_cfg
.
draft_parallel_config
.
tensor_parallel_size
if
draft_tp
!=
tgt_tp
:
if
draft_tp
!=
tgt_tp
:
...
@@ -190,12 +189,14 @@ def create_vllm_config_for_draft_model(
...
@@ -190,12 +189,14 @@ def create_vllm_config_for_draft_model(
The vllm_config is useful when loading the draft model with get_model().
The vllm_config is useful when loading the draft model with get_model().
"""
"""
old
=
target_model_vllm_config
old
=
target_model_vllm_config
new_parallel_config
=
old
.
speculative_config
.
draft_parallel_config
.
replace
(
assert
old
.
speculative_config
is
not
None
,
"speculative_config is not set"
old_spec_config
=
old
.
speculative_config
new_parallel_config
=
old_spec_config
.
draft_parallel_config
.
replace
(
rank
=
old
.
parallel_config
.
rank
rank
=
old
.
parallel_config
.
rank
)
)
new
:
VllmConfig
=
old
.
replace
(
new
:
VllmConfig
=
old
.
replace
(
quant_config
=
None
,
# quant_config is recomputed in __init__()
quant_config
=
None
,
# quant_config is recomputed in __init__()
model_config
=
old
.
speculative_config
.
draft_model_config
,
model_config
=
old_spec_config
.
draft_model_config
,
parallel_config
=
new_parallel_config
,
parallel_config
=
new_parallel_config
,
)
)
return
new
return
new
...
...
vllm/v1/spec_decode/eagle.py
View file @
726d8972
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
ast
import
ast
from
dataclasses
import
replace
from
dataclasses
import
replace
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
from
typing
import
cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -20,6 +21,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...
@@ -20,6 +21,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.models
import
supports_multimodal
from
vllm.model_executor.models
import
supports_multimodal
from
vllm.model_executor.models.deepseek_v2
import
DeepseekV32IndexerCache
from
vllm.model_executor.models.deepseek_v2
import
DeepseekV32IndexerCache
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -62,8 +64,8 @@ class SpecDecodeBaseProposer:
...
@@ -62,8 +64,8 @@ class SpecDecodeBaseProposer:
runner
=
None
,
runner
=
None
,
):
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
assert
vllm_config
.
speculative_config
is
not
None
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
speculative_config
=
vllm_config
.
speculative_config
assert
self
.
speculative_config
is
not
None
self
.
draft_model_config
=
self
.
speculative_config
.
draft_model_config
self
.
draft_model_config
=
self
.
speculative_config
.
draft_model_config
self
.
method
=
self
.
speculative_config
.
method
self
.
method
=
self
.
speculative_config
.
method
self
.
pass_hidden_states_to_model
=
pass_hidden_states_to_model
self
.
pass_hidden_states_to_model
=
pass_hidden_states_to_model
...
@@ -206,6 +208,7 @@ class SpecDecodeBaseProposer:
...
@@ -206,6 +208,7 @@ class SpecDecodeBaseProposer:
# Parse the speculative token tree.
# Parse the speculative token tree.
spec_token_tree
=
self
.
speculative_config
.
speculative_token_tree
spec_token_tree
=
self
.
speculative_config
.
speculative_token_tree
assert
spec_token_tree
is
not
None
self
.
tree_choices
:
list
[
tuple
[
int
,
...]]
=
ast
.
literal_eval
(
spec_token_tree
)
self
.
tree_choices
:
list
[
tuple
[
int
,
...]]
=
ast
.
literal_eval
(
spec_token_tree
)
tree_depth
=
len
(
self
.
tree_choices
[
-
1
])
tree_depth
=
len
(
self
.
tree_choices
[
-
1
])
# Precompute per-level properties of the tree.
# Precompute per-level properties of the tree.
...
@@ -1077,9 +1080,12 @@ class SpecDecodeBaseProposer:
...
@@ -1077,9 +1080,12 @@ class SpecDecodeBaseProposer:
return
model
.
__class__
.
__name__
return
model
.
__class__
.
__name__
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
draft_model_config
=
self
.
vllm_config
.
speculative_config
.
draft_model_config
draft_model_config
=
self
.
speculative_config
.
draft_model_config
target_attn_layer_names
=
set
(
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
).
keys
()
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
,
# type: ignore[type-abstract]
).
keys
()
)
)
# FIXME: support hybrid kv for draft model
# FIXME: support hybrid kv for draft model
target_indexer_layer_names
=
set
(
target_indexer_layer_names
=
set
(
...
@@ -1096,7 +1102,10 @@ class SpecDecodeBaseProposer:
...
@@ -1096,7 +1102,10 @@ class SpecDecodeBaseProposer:
)
)
draft_attn_layer_names
=
(
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
).
keys
()
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
,
# type: ignore[type-abstract]
).
keys
()
-
target_attn_layer_names
-
target_attn_layer_names
)
)
indexer_layers
=
get_layers_from_vllm_config
(
indexer_layers
=
get_layers_from_vllm_config
(
...
@@ -1136,6 +1145,7 @@ class SpecDecodeBaseProposer:
...
@@ -1136,6 +1145,7 @@ class SpecDecodeBaseProposer:
if
supports_multimodal
(
target_model
):
if
supports_multimodal
(
target_model
):
# handle multimodality
# handle multimodality
assert
hasattr
(
target_model
,
"config"
)
if
self
.
get_model_name
(
target_model
)
in
[
if
self
.
get_model_name
(
target_model
)
in
[
"Qwen2_5_VLForConditionalGeneration"
,
"Qwen2_5_VLForConditionalGeneration"
,
"Qwen3VLForConditionalGeneration"
,
"Qwen3VLForConditionalGeneration"
,
...
@@ -1152,16 +1162,21 @@ class SpecDecodeBaseProposer:
...
@@ -1152,16 +1162,21 @@ class SpecDecodeBaseProposer:
self
.
model
.
config
.
image_token_index
=
(
self
.
model
.
config
.
image_token_index
=
(
target_model
.
config
.
image_token_index
target_model
.
config
.
image_token_index
)
)
target_language_model
=
target_model
.
get_language_model
()
target_language_model
=
cast
(
SupportsMultiModal
,
target_model
).
get_language_model
()
else
:
else
:
target_language_model
=
target_model
target_language_model
=
target_model
# share embed_tokens with the target model if needed
# share embed_tokens with the target model if needed
if
get_pp_group
().
world_size
==
1
:
if
get_pp_group
().
world_size
==
1
:
if
hasattr
(
target_language_model
.
model
,
"embed_tokens"
):
inner_model
=
getattr
(
target_language_model
,
"model"
,
None
)
target_embed_tokens
=
target_language_model
.
model
.
embed_tokens
if
inner_model
is
None
:
elif
hasattr
(
target_language_model
.
model
,
"embedding"
):
raise
AttributeError
(
"Target model does not have 'model' attribute"
)
target_embed_tokens
=
target_language_model
.
model
.
embedding
if
hasattr
(
inner_model
,
"embed_tokens"
):
target_embed_tokens
=
inner_model
.
embed_tokens
elif
hasattr
(
inner_model
,
"embedding"
):
target_embed_tokens
=
inner_model
.
embedding
else
:
else
:
raise
AttributeError
(
raise
AttributeError
(
"Target model does not have 'embed_tokens' or 'embedding' attribute"
"Target model does not have 'embed_tokens' or 'embedding' attribute"
...
...
vllm/v1/spec_decode/medusa.py
View file @
726d8972
...
@@ -27,11 +27,13 @@ class MedusaProposer:
...
@@ -27,11 +27,13 @@ class MedusaProposer:
):
):
# Save config parameters
# Save config parameters
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
assert
vllm_config
.
speculative_config
is
not
None
,
(
"Speculative config must be set"
)
self
.
spec_config
=
vllm_config
.
speculative_config
self
.
device
=
device
self
.
device
=
device
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
hidden_size
=
(
self
.
hidden_size
=
self
.
spec_config
.
draft_model_config
.
get_hidden_size
()
vllm_config
.
speculative_config
.
draft_model_config
.
get_hidden_size
()
)
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
dtype
=
vllm_config
.
model_config
.
dtype
def
propose
(
def
propose
(
...
@@ -58,7 +60,7 @@ class MedusaProposer:
...
@@ -58,7 +60,7 @@ class MedusaProposer:
with
set_model_tag
(
"medusa_head"
):
with
set_model_tag
(
"medusa_head"
):
self
.
model
=
get_model
(
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
vllm_config
=
self
.
vllm_config
,
model_config
=
self
.
vllm_config
.
speculative_config
.
draft_model_config
,
model_config
=
self
.
spec_config
.
draft_model_config
,
)
)
assert
not
(
assert
not
(
is_mixture_of_experts
(
self
.
model
)
is_mixture_of_experts
(
self
.
model
)
...
...
vllm/v1/spec_decode/suffix_decoding.py
View file @
726d8972
...
@@ -15,6 +15,7 @@ class SuffixDecodingProposer:
...
@@ -15,6 +15,7 @@ class SuffixDecodingProposer:
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
config
=
vllm_config
.
speculative_config
config
=
vllm_config
.
speculative_config
assert
config
is
not
None
,
"Speculative config must be set"
self
.
num_speculative_tokens
=
config
.
num_speculative_tokens
self
.
num_speculative_tokens
=
config
.
num_speculative_tokens
self
.
max_tree_depth
=
config
.
suffix_decoding_max_tree_depth
self
.
max_tree_depth
=
config
.
suffix_decoding_max_tree_depth
self
.
max_spec_factor
=
config
.
suffix_decoding_max_spec_factor
self
.
max_spec_factor
=
config
.
suffix_decoding_max_spec_factor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment