Unverified Commit 6f2f59f2 authored by Zhengkai Zhang's avatar Zhengkai Zhang Committed by GitHub
Browse files

[Misc][Spec Decode] support different load config for draft model (#34022)


Signed-off-by: default avatarzzhengkai <zzhengkai@devgpu049.ldc1.facebook.com>
Co-authored-by: default avatarzzhengkai <zzhengkai@devgpu049.ldc1.facebook.com>
parent bb2fc8b5
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, get_args ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator from pydantic import Field, SkipValidation, model_validator
from typing_extensions import Self from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.model import ModelConfig from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config from vllm.config.utils import config
...@@ -160,6 +161,10 @@ class SpeculativeConfig: ...@@ -160,6 +161,10 @@ class SpeculativeConfig:
tokens with estimated probability (based on frequency counts) greater than tokens with estimated probability (based on frequency counts) greater than
or equal to this value.""" or equal to this value."""
draft_load_config: LoadConfig | None = None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
......
...@@ -128,8 +128,9 @@ def get_model( ...@@ -128,8 +128,9 @@ def get_model(
vllm_config: VllmConfig, vllm_config: VllmConfig,
model_config: ModelConfig | None = None, model_config: ModelConfig | None = None,
prefix: str = "", prefix: str = "",
load_config: LoadConfig | None = None,
) -> nn.Module: ) -> nn.Module:
loader = get_model_loader(vllm_config.load_config) loader = get_model_loader(load_config or vllm_config.load_config)
if model_config is None: if model_config is None:
model_config = vllm_config.model_config model_config = vllm_config.model_config
return loader.load_model( return loader.load_model(
......
...@@ -1286,6 +1286,7 @@ class SpecDecodeBaseProposer: ...@@ -1286,6 +1286,7 @@ class SpecDecodeBaseProposer:
model = get_model( model = get_model(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
model_config=self.speculative_config.draft_model_config, model_config=self.speculative_config.draft_model_config,
load_config=self.speculative_config.draft_load_config,
) )
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment