Unverified commit b5782fcd, authored by KrishnanPrash and committed by GitHub
Browse files

feat: Add --custom-jinja-template argument to pass a custom chat template for TRTLLM (#3332)


Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com>
parent 9b9536d0
...@@ -335,6 +335,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -335,6 +335,7 @@ async def init(runtime: DistributedRuntime, config: Config):
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit, migration_limit=config.migration_limit,
runtime_config=runtime_config, runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
) )
# Get health check payload (checks env var and falls back to TensorRT-LLM default) # Get health check payload (checks env var and falls back to TensorRT-LLM default)
......
...@@ -59,6 +59,7 @@ class Config: ...@@ -59,6 +59,7 @@ class Config:
self.max_file_size_mb: int = 50 self.max_file_size_mb: int = 50
self.reasoning_parser: Optional[str] = None self.reasoning_parser: Optional[str] = None
self.tool_call_parser: Optional[str] = None self.tool_call_parser: Optional[str] = None
self.custom_jinja_template: Optional[str] = None
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
...@@ -89,7 +90,8 @@ class Config: ...@@ -89,7 +90,8 @@ class Config:
f"allowed_local_media_path={self.allowed_local_media_path}, " f"allowed_local_media_path={self.allowed_local_media_path}, "
f"max_file_size_mb={self.max_file_size_mb}, " f"max_file_size_mb={self.max_file_size_mb}, "
f"reasoning_parser={self.reasoning_parser}, " f"reasoning_parser={self.reasoning_parser}, "
f"tool_call_parser={self.tool_call_parser}" f"tool_call_parser={self.tool_call_parser}, "
f"custom_jinja_template={self.custom_jinja_template}"
) )
...@@ -296,6 +298,12 @@ def cmd_line_args(): ...@@ -296,6 +298,12 @@ def cmd_line_args():
choices=get_reasoning_parser_names(), choices=get_reasoning_parser_names(),
help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.", help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
) )
parser.add_argument(
"--custom-jinja-template",
type=str,
default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
)
args = parser.parse_args() args = parser.parse_args()
...@@ -367,6 +375,15 @@ def cmd_line_args(): ...@@ -367,6 +375,15 @@ def cmd_line_args():
config.reasoning_parser = args.dyn_reasoning_parser config.reasoning_parser = args.dyn_reasoning_parser
config.tool_call_parser = args.dyn_tool_call_parser config.tool_call_parser = args.dyn_tool_call_parser
# Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template:
expanded_template_path = os.path.expandvars(
os.path.expanduser(args.custom_jinja_template)
)
config.custom_jinja_template = expanded_template_path
else:
config.custom_jinja_template = None
return config return config
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment