Unverified Commit 80221e18 authored by Xingyu Liu's avatar Xingyu Liu Committed by GitHub
Browse files

[BugFix]Fix eagle draft_model_config and add tests (#31753)


Signed-off-by: default avatarXingyu Liu <charlotteliu12x@gmail.com>
parent 5e714f7f
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
"EagleDeepSeekMTPModel" "EagleDeepSeekMTPModel"
], ],
"model_type": "eagle", "model_type": "eagle",
"text_model_type": "deepseek_mtp", "text_model_type": "eagle",
"hidden_size": 2560, "hidden_size": 2560,
"total_num_hidden_layers": 1, "total_num_hidden_layers": 1,
"total_num_attention_heads": 32, "total_num_attention_heads": 32,
...@@ -55,7 +55,7 @@ ...@@ -55,7 +55,7 @@
"EagleLlamaForCausalLM" "EagleLlamaForCausalLM"
], ],
"model_type": "eagle", "model_type": "eagle",
"text_model_type": "llama", "text_model_type": "eagle",
"hidden_size": 4096, "hidden_size": 4096,
"total_num_hidden_layers": 1, "total_num_hidden_layers": 1,
"total_num_attention_heads": 32, "total_num_attention_heads": 32,
...@@ -72,7 +72,7 @@ ...@@ -72,7 +72,7 @@
"Eagle3LlamaForCausalLM" "Eagle3LlamaForCausalLM"
], ],
"model_type": "eagle", "model_type": "eagle",
"text_model_type": "llama", "text_model_type": "eagle",
"hidden_size": 4096, "hidden_size": 4096,
"total_num_hidden_layers": 1, "total_num_hidden_layers": 1,
"total_num_attention_heads": 32, "total_num_attention_heads": 32,
......
...@@ -13,8 +13,10 @@ from vllm.compilation.backends import VllmBackend ...@@ -13,8 +13,10 @@ from vllm.compilation.backends import VllmBackend
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
ModelConfig, ModelConfig,
ParallelConfig,
PoolerConfig, PoolerConfig,
SchedulerConfig, SchedulerConfig,
SpeculativeConfig,
VllmConfig, VllmConfig,
update_config, update_config,
) )
...@@ -1105,3 +1107,23 @@ def test_needs_dp_coordination( ...@@ -1105,3 +1107,23 @@ def test_needs_dp_coordination(
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config) vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
assert vllm_config.needs_dp_coordinator == expected_needs_coordinator assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
def test_eagle_draft_model_config():
"""Test that EagleDraft model config is correctly set."""
target_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
)
speculative_config = SpeculativeConfig(
model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
num_speculative_tokens=1,
target_model_config=target_model_config,
target_parallel_config=ParallelConfig(),
)
draft_model_config = speculative_config.draft_model_config
assert draft_model_config.hf_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.hf_text_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.hf_config.model_type == "eagle"
assert draft_model_config.hf_text_config.model_type == "eagle"
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.architecture == "EagleLlamaForCausalLM"
...@@ -12,6 +12,7 @@ from vllm.config.model import ModelConfig ...@@ -12,6 +12,7 @@ from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference from vllm.utils.import_utils import LazyLoader, has_arctic_inference
...@@ -409,10 +410,23 @@ class SpeculativeConfig: ...@@ -409,10 +410,23 @@ class SpeculativeConfig:
method=self.method, method=self.method,
model_type="eagle", model_type="eagle",
) )
# EAGLEConfig primarily updates architectures, so update
# all architectures-related fields in draft_model_config
self.draft_model_config.hf_config = eagle_config self.draft_model_config.hf_config = eagle_config
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = ( self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config() self.draft_model_config.get_model_arch_config()
) )
model_info, arch = (
self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
if self.num_speculative_tokens is not None and hasattr( if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens" self.draft_model_config.hf_config, "num_lookahead_tokens"
......
...@@ -201,7 +201,7 @@ class ModelArchConfigConvertorBase: ...@@ -201,7 +201,7 @@ class ModelArchConfigConvertorBase:
# underlying architecture # underlying architecture
return ( return (
self.hf_text_config.model.model_type self.hf_text_config.model.model_type
in ("deepseek_v2", "deepseek_v3", "deepseek_v32") in ("deepseek_v2", "deepseek_v3", "deepseek_v32", "deepseek_mtp")
and self.hf_text_config.kv_lora_rank is not None and self.hf_text_config.kv_lora_rank is not None
) )
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment