Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a528260
Unverified
Commit
9a528260
authored
Apr 05, 2026
by
Aaron Batilo
Committed by
GitHub
Apr 05, 2026
Browse files
[Bugfix][Spec Decode] Fix extract_hidden_states for VLM models (#38987)
Signed-off-by:
Aaron Batilo
<
abatilo@coreweave.com
>
parent
968ed02a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
173 additions
and
0 deletions
+173
-0
tests/v1/spec_decode/test_extract_hidden_states.py
tests/v1/spec_decode/test_extract_hidden_states.py
+163
-0
vllm/transformers_utils/configs/extract_hidden_states.py
vllm/transformers_utils/configs/extract_hidden_states.py
+10
-0
No files found.
tests/v1/spec_decode/test_extract_hidden_states.py
View file @
9a528260
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
unittest
import
mock
from
unittest
import
mock
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
from
transformers
import
CLIPVisionConfig
,
LlamaConfig
,
LlavaConfig
,
PretrainedConfig
from
tests.v1.attention.utils
import
(
from
tests.v1.attention.utils
import
(
BatchSpec
,
BatchSpec
,
...
@@ -23,6 +25,10 @@ from vllm.config import (
...
@@ -23,6 +25,10 @@ from vllm.config import (
)
)
from
vllm.config.load
import
LoadConfig
from
vllm.config.load
import
LoadConfig
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_hf_text_config
from
vllm.transformers_utils.configs.extract_hidden_states
import
(
ExtractHiddenStatesConfig
,
)
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
@@ -323,3 +329,160 @@ def test_propose_different_layer_counts(num_hidden_layers):
...
@@ -323,3 +329,160 @@ def test_propose_different_layer_counts(num_hidden_layers):
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
torch
.
equal
(
draft_tokens
,
sampled_token_ids
)
assert
torch
.
equal
(
draft_tokens
,
sampled_token_ids
)
# ---------------------------------------------------------------------------
# VLM / composite config tests for ExtractHiddenStatesConfig
# ---------------------------------------------------------------------------
class _DummyVLMConfig(PretrainedConfig):
    """Bare-bones composite config standing in for a VLM (e.g. Kimi-K2.5, LLaVA).

    Text-model parameters such as ``hidden_size`` or ``num_attention_heads``
    are only reachable through the nested ``text_config``; this wrapper does
    not define them at the top level.
    """

    model_type = "test_vlm"

    def __init__(self, text_config: PretrainedConfig, **kwargs):
        """Store the nested text config, then run the base initializer.

        Args:
            text_config: Config object exposed as ``self.text_config``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Assigned before the base initializer runs, same as the original order.
        self.text_config = text_config
        declared_architectures = ["LlamaForCausalLM"]
        super().__init__(architectures=declared_architectures, **kwargs)

    def get_text_config(self, decoder: bool = False) -> PretrainedConfig:
        """Return the nested text config; ``decoder`` is accepted but ignored."""
        del decoder
        return self.text_config
def test_extract_hidden_states_text_only_config_regression():
    """Regression guard: targets without a nested text_config keep working."""
    target = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
    draft_overrides = {
        "hf_config": {
            "eagle_aux_hidden_state_layer_ids": [1, 2, 3, 4],
        }
    }

    spec = SpeculativeConfig(
        target_model_config=target,
        target_parallel_config=ParallelConfig(),
        method="extract_hidden_states",
        num_speculative_tokens=1,
        draft_model_config=draft_overrides,
    )

    draft = spec.draft_model_config
    assert draft is not None

    # With no nested text_config, the text config is the hf_config object itself.
    assert draft.hf_text_config is draft.hf_config

    expected_heads = target.hf_text_config.num_attention_heads
    assert draft.hf_text_config.num_attention_heads == expected_heads
def test_extract_hidden_states_config_preserves_vlm_text_config():
    """The nested text_config of a real VLM config (LLaVA) stays a config object."""
    llama_text_config = LlamaConfig(
        vocab_size=32000,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=8,
    )
    llava_config = LlavaConfig(
        vision_config=CLIPVisionConfig(),
        text_config=llama_text_config,
    )
    expected_heads = llama_text_config.num_attention_heads

    # Precondition: to_dict() turns the nested config into a plain dict.
    flattened = llava_config.to_dict()
    assert isinstance(flattened["text_config"], dict)

    extract_config = ExtractHiddenStatesConfig(
        llava_config,
        eagle_aux_hidden_state_layer_ids=[1, 2],
    )

    # Behaviour under test: text_config remains a config object, not a dict.
    assert isinstance(extract_config.text_config, LlamaConfig)

    resolved = get_hf_text_config(extract_config)
    assert resolved is extract_config.text_config
    assert resolved.num_attention_heads == expected_heads
    assert resolved.hidden_size == llama_text_config.hidden_size

    # Serialization still yields a plain nested dict with the same values.
    as_dict = extract_config.to_dict()
    assert isinstance(as_dict["text_config"], dict)
    assert as_dict["text_config"]["num_attention_heads"] == expected_heads

    # Same check through the JSON string path; the parsed result is a dict.
    parsed_json = json.loads(extract_config.to_json_string())
    assert parsed_json["text_config"]["num_attention_heads"] == expected_heads
def test_extract_hidden_states_speculative_config_vlm():
    """Building a SpeculativeConfig on top of a VLM target must succeed."""
    inner_text_config = LlamaConfig(
        vocab_size=32000,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=8,
    )

    target = ModelConfig(
        model=model_dir,
        runner="generate",
        max_model_len=100,
    )
    # Swap the loaded config for the composite VLM stand-in, and point the
    # target's text config at the nested object.
    target.hf_config = _DummyVLMConfig(text_config=inner_text_config)
    target.hf_text_config = inner_text_config

    draft_overrides = {
        "hf_config": {
            "eagle_aux_hidden_state_layer_ids": [1, 2],
        }
    }
    spec = SpeculativeConfig(
        target_model_config=target,
        target_parallel_config=ParallelConfig(),
        method="extract_hidden_states",
        num_speculative_tokens=1,
        draft_model_config=draft_overrides,
    )

    draft = spec.draft_model_config
    assert draft is not None

    # The nested text config must remain a LlamaConfig on the draft side ...
    assert isinstance(draft.hf_config.text_config, LlamaConfig)
    # ... and be the very object reported as the draft's text config.
    assert draft.hf_text_config is draft.hf_config.text_config

    expected_heads = inner_text_config.num_attention_heads
    assert draft.hf_text_config.num_attention_heads == expected_heads
def test_extract_hidden_states_config_invalid_text_config():
    """A nested text_config lacking required attributes is still rejected."""
    incomplete_text_config = PretrainedConfig(hidden_size=128)
    composite = _DummyVLMConfig(text_config=incomplete_text_config)

    extract_config = ExtractHiddenStatesConfig(
        composite,
        eagle_aux_hidden_state_layer_ids=[1],
    )

    # The nested object is carried over as-is rather than flattened.
    assert extract_config.text_config is incomplete_text_config

    # Validation must still complain about the missing attribute.
    with pytest.raises(ValueError, match="num_attention_heads"):
        get_hf_text_config(extract_config)
vllm/transformers_utils/configs/extract_hidden_states.py
View file @
9a528260
...
@@ -23,10 +23,14 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
...
@@ -23,10 +23,14 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
if
isinstance
(
model
,
dict
):
if
isinstance
(
model
,
dict
):
model_dict
=
model
model_dict
=
model
source_text_config
=
None
elif
isinstance
(
model
,
PretrainedConfig
):
elif
isinstance
(
model
,
PretrainedConfig
):
model_dict
=
model
.
to_dict
()
model_dict
=
model
.
to_dict
()
text_config
=
model
.
get_text_config
()
source_text_config
=
text_config
if
text_config
is
not
model
else
None
else
:
else
:
model_dict
=
{}
model_dict
=
{}
source_text_config
=
None
# Combine: model_dict first, then kwargs override
# Combine: model_dict first, then kwargs override
combined
=
{
**
model_dict
,
**
kwargs
}
combined
=
{
**
model_dict
,
**
kwargs
}
...
@@ -35,6 +39,12 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
...
@@ -35,6 +39,12 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
combined
[
"architectures"
]
=
[
"ExtractHiddenStatesModel"
]
combined
[
"architectures"
]
=
[
"ExtractHiddenStatesModel"
]
# to_dict() and kwargs both flatten text_config to a plain dict;
# downstream get_hf_text_config() needs it as a PretrainedConfig
# for attribute access. Re-insert the original object.
if
source_text_config
is
not
None
:
combined
[
"text_config"
]
=
source_text_config
super
().
__init__
(
**
combined
)
super
().
__init__
(
**
combined
)
@
classmethod
@
classmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment