Unverified commit 5c9a2d49, authored by richardhuo-nv and committed by GitHub
Browse files

fix: add speculative decoding config to dynamo serve + trtllm (#1356)

parent b8dc0150
......@@ -22,6 +22,7 @@ from typing import Any, Dict, Tuple
import yaml
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
@dataclass
......@@ -32,12 +33,14 @@ class LLMAPIConfig:
model_path: str | None = None,
pytorch_backend_config: PyTorchConfig | None = None,
kv_cache_config: KvCacheConfig | None = None,
speculative_config: DecodingBaseConfig | None = None,
**kwargs,
):
self.model_name = model_name
self.model_path = model_path
self.pytorch_backend_config = pytorch_backend_config
self.kv_cache_config = kv_cache_config
self.speculative_config = speculative_config
self.extra_args = kwargs
# Hardcoded to skip tokenizer init for now.
......@@ -51,6 +54,7 @@ class LLMAPIConfig:
data = {
"pytorch_backend_config": self.pytorch_backend_config,
"kv_cache_config": self.kv_cache_config,
"speculative_config": self.speculative_config,
"skip_tokenizer_init": self.skip_tokenizer_init,
}
if self.extra_args:
......@@ -68,6 +72,12 @@ class LLMAPIConfig:
self.kv_cache_config = KvCacheConfig(**other_config["kv_cache_config"])
self.extra_args.pop("kv_cache_config", None)
if "speculative_config" in other_config:
self.speculative_config = DecodingBaseConfig.from_dict(
other_config["speculative_config"]
)
self.extra_args.pop("speculative_config", None)
def _get_llm_args(engine_config):
# Only do model validation checks and leave other checks to LLMAPI
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment