Unverified Commit 9a528260 authored by Aaron Batilo's avatar Aaron Batilo Committed by GitHub
Browse files

[Bugfix][Spec Decode] Fix extract_hidden_states for VLM models (#38987)


Signed-off-by: default avatarAaron Batilo <abatilo@coreweave.com>
parent 968ed02a
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from unittest import mock from unittest import mock
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig, PretrainedConfig
from tests.v1.attention.utils import ( from tests.v1.attention.utils import (
BatchSpec, BatchSpec,
...@@ -23,6 +25,10 @@ from vllm.config import ( ...@@ -23,6 +25,10 @@ from vllm.config import (
) )
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_hf_text_config
from vllm.transformers_utils.configs.extract_hidden_states import (
ExtractHiddenStatesConfig,
)
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -323,3 +329,160 @@ def test_propose_different_layer_counts(num_hidden_layers): ...@@ -323,3 +329,160 @@ def test_propose_different_layer_counts(num_hidden_layers):
assert draft_tokens.shape == (batch_size, 1) assert draft_tokens.shape == (batch_size, 1)
assert torch.equal(draft_tokens, sampled_token_ids) assert torch.equal(draft_tokens, sampled_token_ids)
# ---------------------------------------------------------------------------
# VLM / composite config tests for ExtractHiddenStatesConfig
# ---------------------------------------------------------------------------
class _DummyVLMConfig(PretrainedConfig):
"""Minimal composite config that mimics VLMs like Kimi-K2.5 or LLaVA.
The text model's parameters (hidden_size, num_attention_heads, …) live
exclusively under ``text_config``; the top-level config has none of them.
"""
model_type = "test_vlm"
def __init__(self, text_config: PretrainedConfig, **kwargs):
self.text_config = text_config
super().__init__(architectures=["LlamaForCausalLM"], **kwargs)
def get_text_config(self, decoder: bool = False) -> PretrainedConfig:
del decoder
return self.text_config
def test_extract_hidden_states_text_only_config_regression():
"""Text-only models (no nested text_config) must keep working."""
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
method="extract_hidden_states",
num_speculative_tokens=1,
draft_model_config={
"hf_config": {
"eagle_aux_hidden_state_layer_ids": [1, 2, 3, 4],
}
},
)
assert speculative_config.draft_model_config is not None
# For text-only models, hf_text_config should be the config itself.
assert speculative_config.draft_model_config.hf_text_config is (
speculative_config.draft_model_config.hf_config
)
assert (
speculative_config.draft_model_config.hf_text_config.num_attention_heads
== model_config.hf_text_config.num_attention_heads
)
def test_extract_hidden_states_config_preserves_vlm_text_config():
"""A real VLM config (LLaVA) with nested text_config must be preserved."""
text_config = LlamaConfig(
vocab_size=32000,
hidden_size=128,
intermediate_size=256,
num_hidden_layers=2,
num_attention_heads=8,
)
vlm_config = LlavaConfig(
vision_config=CLIPVisionConfig(),
text_config=text_config,
)
# Precondition: to_dict() flattens the nested config to a plain dict.
assert isinstance(vlm_config.to_dict()["text_config"], dict)
extract_config = ExtractHiddenStatesConfig(
vlm_config,
eagle_aux_hidden_state_layer_ids=[1, 2],
)
# The fix: text_config is still a PretrainedConfig, not a dict.
assert isinstance(extract_config.text_config, LlamaConfig)
extracted = get_hf_text_config(extract_config)
assert extracted is extract_config.text_config
assert extracted.num_attention_heads == text_config.num_attention_heads
assert extracted.hidden_size == text_config.hidden_size
# Serialization must still round-trip correctly.
serialized = extract_config.to_dict()
assert isinstance(serialized["text_config"], dict)
assert serialized["text_config"]["num_attention_heads"] == (
text_config.num_attention_heads
)
json_str = json.loads(extract_config.to_json_string())
assert json_str["text_config"]["num_attention_heads"] == (
text_config.num_attention_heads
)
def test_extract_hidden_states_speculative_config_vlm():
"""SpeculativeConfig with a VLM target must build without errors."""
nested_text_config = LlamaConfig(
vocab_size=32000,
hidden_size=128,
intermediate_size=256,
num_hidden_layers=2,
num_attention_heads=8,
)
target_model_config = ModelConfig(
model=model_dir,
runner="generate",
max_model_len=100,
)
# Replace the real text-only config with our composite VLM config.
target_model_config.hf_config = _DummyVLMConfig(
text_config=nested_text_config,
)
target_model_config.hf_text_config = nested_text_config
speculative_config = SpeculativeConfig(
target_model_config=target_model_config,
target_parallel_config=ParallelConfig(),
method="extract_hidden_states",
num_speculative_tokens=1,
draft_model_config={
"hf_config": {
"eagle_aux_hidden_state_layer_ids": [1, 2],
}
},
)
assert speculative_config.draft_model_config is not None
assert isinstance(
speculative_config.draft_model_config.hf_config.text_config,
LlamaConfig,
)
assert speculative_config.draft_model_config.hf_text_config is (
speculative_config.draft_model_config.hf_config.text_config
)
assert (
speculative_config.draft_model_config.hf_text_config.num_attention_heads
== nested_text_config.num_attention_heads
)
def test_extract_hidden_states_config_invalid_text_config():
"""A nested text_config missing required attrs must still be rejected."""
broken_text_config = PretrainedConfig(hidden_size=128)
vlm_config = _DummyVLMConfig(text_config=broken_text_config)
extract_config = ExtractHiddenStatesConfig(
vlm_config,
eagle_aux_hidden_state_layer_ids=[1],
)
# The object is preserved (not flattened), …
assert extract_config.text_config is broken_text_config
# … but validation still rejects the missing attribute.
with pytest.raises(ValueError, match="num_attention_heads"):
get_hf_text_config(extract_config)
...@@ -23,10 +23,14 @@ class ExtractHiddenStatesConfig(PretrainedConfig): ...@@ -23,10 +23,14 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
if isinstance(model, dict): if isinstance(model, dict):
model_dict = model model_dict = model
source_text_config = None
elif isinstance(model, PretrainedConfig): elif isinstance(model, PretrainedConfig):
model_dict = model.to_dict() model_dict = model.to_dict()
text_config = model.get_text_config()
source_text_config = text_config if text_config is not model else None
else: else:
model_dict = {} model_dict = {}
source_text_config = None
# Combine: model_dict first, then kwargs override # Combine: model_dict first, then kwargs override
combined = {**model_dict, **kwargs} combined = {**model_dict, **kwargs}
...@@ -35,6 +39,12 @@ class ExtractHiddenStatesConfig(PretrainedConfig): ...@@ -35,6 +39,12 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
combined["architectures"] = ["ExtractHiddenStatesModel"] combined["architectures"] = ["ExtractHiddenStatesModel"]
# to_dict() and kwargs both flatten text_config to a plain dict;
# downstream get_hf_text_config() needs it as a PretrainedConfig
# for attribute access. Re-insert the original object.
if source_text_config is not None:
combined["text_config"] = source_text_config
super().__init__(**combined) super().__init__(**combined)
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment