Unverified Commit 6f2f59f2 authored by Zhengkai Zhang's avatar Zhengkai Zhang Committed by GitHub
Browse files

[Misc][Spec Decode] support different load config for draft model (#34022)


Signed-off-by: default avatarzzhengkai <zzhengkai@devgpu049.ldc1.facebook.com>
Co-authored-by: default avatarzzhengkai <zzhengkai@devgpu049.ldc1.facebook.com>
parent bb2fc8b5
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
......@@ -160,6 +161,10 @@ class SpeculativeConfig:
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""
draft_load_config: LoadConfig | None = None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......
......@@ -128,8 +128,9 @@ def get_model(
vllm_config: VllmConfig,
model_config: ModelConfig | None = None,
prefix: str = "",
load_config: LoadConfig | None = None,
) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
loader = get_model_loader(load_config or vllm_config.load_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(
......
......@@ -1286,6 +1286,7 @@ class SpecDecodeBaseProposer:
model = get_model(
vllm_config=self.vllm_config,
model_config=self.speculative_config.draft_model_config,
load_config=self.speculative_config.draft_load_config,
)
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment