Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a528260
Unverified
Commit
9a528260
authored
Apr 05, 2026
by
Aaron Batilo
Committed by
GitHub
Apr 05, 2026
Browse files
[Bugfix][Spec Decode] Fix extract_hidden_states for VLM models (#38987)
Signed-off-by:
Aaron Batilo
<
abatilo@coreweave.com
>
parent
968ed02a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
173 additions
and
0 deletions
+173
-0
tests/v1/spec_decode/test_extract_hidden_states.py
tests/v1/spec_decode/test_extract_hidden_states.py
+163
-0
vllm/transformers_utils/configs/extract_hidden_states.py
vllm/transformers_utils/configs/extract_hidden_states.py
+10
-0
No files found.
tests/v1/spec_decode/test_extract_hidden_states.py
View file @
9a528260
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
unittest
import
mock
import
numpy
as
np
import
pytest
import
torch
from
transformers
import
CLIPVisionConfig
,
LlamaConfig
,
LlavaConfig
,
PretrainedConfig
from
tests.v1.attention.utils
import
(
BatchSpec
,
...
...
@@ -23,6 +25,10 @@ from vllm.config import (
)
from
vllm.config.load
import
LoadConfig
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_hf_text_config
from
vllm.transformers_utils.configs.extract_hidden_states
import
(
ExtractHiddenStatesConfig
,
)
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
...
@@ -323,3 +329,160 @@ def test_propose_different_layer_counts(num_hidden_layers):
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
torch
.
equal
(
draft_tokens
,
sampled_token_ids
)
# ---------------------------------------------------------------------------
# VLM / composite config tests for ExtractHiddenStatesConfig
# ---------------------------------------------------------------------------
class
_DummyVLMConfig
(
PretrainedConfig
):
"""Minimal composite config that mimics VLMs like Kimi-K2.5 or LLaVA.
The text model's parameters (hidden_size, num_attention_heads, …) live
exclusively under ``text_config``; the top-level config has none of them.
"""
model_type
=
"test_vlm"
def
__init__
(
self
,
text_config
:
PretrainedConfig
,
**
kwargs
):
self
.
text_config
=
text_config
super
().
__init__
(
architectures
=
[
"LlamaForCausalLM"
],
**
kwargs
)
def
get_text_config
(
self
,
decoder
:
bool
=
False
)
->
PretrainedConfig
:
del
decoder
return
self
.
text_config
def
test_extract_hidden_states_text_only_config_regression
():
"""Text-only models (no nested text_config) must keep working."""
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
speculative_config
=
SpeculativeConfig
(
target_model_config
=
model_config
,
target_parallel_config
=
ParallelConfig
(),
method
=
"extract_hidden_states"
,
num_speculative_tokens
=
1
,
draft_model_config
=
{
"hf_config"
:
{
"eagle_aux_hidden_state_layer_ids"
:
[
1
,
2
,
3
,
4
],
}
},
)
assert
speculative_config
.
draft_model_config
is
not
None
# For text-only models, hf_text_config should be the config itself.
assert
speculative_config
.
draft_model_config
.
hf_text_config
is
(
speculative_config
.
draft_model_config
.
hf_config
)
assert
(
speculative_config
.
draft_model_config
.
hf_text_config
.
num_attention_heads
==
model_config
.
hf_text_config
.
num_attention_heads
)
def
test_extract_hidden_states_config_preserves_vlm_text_config
():
"""A real VLM config (LLaVA) with nested text_config must be preserved."""
text_config
=
LlamaConfig
(
vocab_size
=
32000
,
hidden_size
=
128
,
intermediate_size
=
256
,
num_hidden_layers
=
2
,
num_attention_heads
=
8
,
)
vlm_config
=
LlavaConfig
(
vision_config
=
CLIPVisionConfig
(),
text_config
=
text_config
,
)
# Precondition: to_dict() flattens the nested config to a plain dict.
assert
isinstance
(
vlm_config
.
to_dict
()[
"text_config"
],
dict
)
extract_config
=
ExtractHiddenStatesConfig
(
vlm_config
,
eagle_aux_hidden_state_layer_ids
=
[
1
,
2
],
)
# The fix: text_config is still a PretrainedConfig, not a dict.
assert
isinstance
(
extract_config
.
text_config
,
LlamaConfig
)
extracted
=
get_hf_text_config
(
extract_config
)
assert
extracted
is
extract_config
.
text_config
assert
extracted
.
num_attention_heads
==
text_config
.
num_attention_heads
assert
extracted
.
hidden_size
==
text_config
.
hidden_size
# Serialization must still round-trip correctly.
serialized
=
extract_config
.
to_dict
()
assert
isinstance
(
serialized
[
"text_config"
],
dict
)
assert
serialized
[
"text_config"
][
"num_attention_heads"
]
==
(
text_config
.
num_attention_heads
)
json_str
=
json
.
loads
(
extract_config
.
to_json_string
())
assert
json_str
[
"text_config"
][
"num_attention_heads"
]
==
(
text_config
.
num_attention_heads
)
def
test_extract_hidden_states_speculative_config_vlm
():
"""SpeculativeConfig with a VLM target must build without errors."""
nested_text_config
=
LlamaConfig
(
vocab_size
=
32000
,
hidden_size
=
128
,
intermediate_size
=
256
,
num_hidden_layers
=
2
,
num_attention_heads
=
8
,
)
target_model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
,
)
# Replace the real text-only config with our composite VLM config.
target_model_config
.
hf_config
=
_DummyVLMConfig
(
text_config
=
nested_text_config
,
)
target_model_config
.
hf_text_config
=
nested_text_config
speculative_config
=
SpeculativeConfig
(
target_model_config
=
target_model_config
,
target_parallel_config
=
ParallelConfig
(),
method
=
"extract_hidden_states"
,
num_speculative_tokens
=
1
,
draft_model_config
=
{
"hf_config"
:
{
"eagle_aux_hidden_state_layer_ids"
:
[
1
,
2
],
}
},
)
assert
speculative_config
.
draft_model_config
is
not
None
assert
isinstance
(
speculative_config
.
draft_model_config
.
hf_config
.
text_config
,
LlamaConfig
,
)
assert
speculative_config
.
draft_model_config
.
hf_text_config
is
(
speculative_config
.
draft_model_config
.
hf_config
.
text_config
)
assert
(
speculative_config
.
draft_model_config
.
hf_text_config
.
num_attention_heads
==
nested_text_config
.
num_attention_heads
)
def
test_extract_hidden_states_config_invalid_text_config
():
"""A nested text_config missing required attrs must still be rejected."""
broken_text_config
=
PretrainedConfig
(
hidden_size
=
128
)
vlm_config
=
_DummyVLMConfig
(
text_config
=
broken_text_config
)
extract_config
=
ExtractHiddenStatesConfig
(
vlm_config
,
eagle_aux_hidden_state_layer_ids
=
[
1
],
)
# The object is preserved (not flattened), …
assert
extract_config
.
text_config
is
broken_text_config
# … but validation still rejects the missing attribute.
with
pytest
.
raises
(
ValueError
,
match
=
"num_attention_heads"
):
get_hf_text_config
(
extract_config
)
vllm/transformers_utils/configs/extract_hidden_states.py
View file @
9a528260
...
...
@@ -23,10 +23,14 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
if
isinstance
(
model
,
dict
):
model_dict
=
model
source_text_config
=
None
elif
isinstance
(
model
,
PretrainedConfig
):
model_dict
=
model
.
to_dict
()
text_config
=
model
.
get_text_config
()
source_text_config
=
text_config
if
text_config
is
not
model
else
None
else
:
model_dict
=
{}
source_text_config
=
None
# Combine: model_dict first, then kwargs override
combined
=
{
**
model_dict
,
**
kwargs
}
...
...
@@ -35,6 +39,12 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
combined
[
"architectures"
]
=
[
"ExtractHiddenStatesModel"
]
# to_dict() and kwargs both flatten text_config to a plain dict;
# downstream get_hf_text_config() needs it as a PretrainedConfig
# for attribute access. Re-insert the original object.
if
source_text_config
is
not
None
:
combined
[
"text_config"
]
=
source_text_config
super
().
__init__
(
**
combined
)
@
classmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment