Unverified Commit 2c3ea294 authored by woodx's avatar woodx Committed by GitHub
Browse files

[Feature] support auto chat template (#4949)

parent 5bb0accb
......@@ -17,7 +17,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
from enum import IntEnum, auto
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from sglang.srt.openai_api.protocol import ChatCompletionRequest
......@@ -407,6 +407,7 @@ class Conversation:
# A global registry for all conversation templates
chat_templates: Dict[str, Conversation] = {}
matching_function_registry: List[Callable] = []
def register_conv_template(template: Conversation, override: bool = False):
......@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
chat_templates[template.name] = template
def register_conv_template_matching_function(func):
matching_function_registry.append(func)
def get_conv_template_by_model_path(model_path):
for matching_func in matching_function_registry:
conv_name = matching_func(model_path)
if conv_name is not None:
return conv_name
return None
def chat_template_exists(template_name: str) -> bool:
return template_name in chat_templates
......@@ -792,3 +805,86 @@ register_conv_template(
audio_token="(<audio>./</audio>)",
)
)
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
if (
"llama" in model_path.lower()
and "3.2" in model_path.lower()
and "vision" in model_path.lower()
):
return "llama_3_vision"
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
if "janus" in model_path.lower():
return "janus-pro"
@register_conv_template_matching_function
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return "vicuna_v1.1"
if "llava-v1.5" in model_path.lower():
return "vicuna_v1.1"
if "llava-next-video-7b" in model_path.lower():
return "vicuna_v1.1"
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
model_path = model_path.lower()
if "llama-2" in model_path and "chat" in model_path:
return "llama-2"
if (
"mistral" in model_path or "mixtral" in model_path
) and "instruct" in model_path:
return "llama-2"
if "codellama" in model_path and "instruct" in model_path:
return "llama-2"
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
model_path = model_path.lower()
if "deepseek" in model_path and "vl2" in model_path:
return "deepseek-vl2"
@register_conv_template_matching_function
def match_chat_ml(model_path: str):
# import pdb;pdb.set_trace()
model_path = model_path.lower()
# Now the suffix for qwen2 chat model is "instruct"
if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
return "gme-qwen2-vl"
if "qwen" in model_path and "vl" in model_path:
return "qwen2-vl"
if (
"llava-v1.6-34b" in model_path
or "llava-v1.6-yi-34b" in model_path
or "llava-next-video-34b" in model_path
or "llava-onevision-qwen2" in model_path
):
return "chatml-llava"
@register_conv_template_matching_function
def match_gemma_it(model_path: str):
model_path = model_path.lower()
if "gemma" in model_path and "it" in model_path:
return "gemma-it"
if "gemma-3" in model_path and "1b" not in model_path:
# gemma-3-1b-it is completion model
return "gemma-it"
@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
model_path = model_path.lower()
if "minicpm-v" in model_path:
return "minicpmv"
elif "minicpm-o" in model_path:
return "minicpmo"
......@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
from sglang.srt.openai_api.adapter import (
guess_chat_template_name_from_model_path,
load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
......@@ -584,6 +587,8 @@ def _launch_subprocesses(
load_chat_template_for_openai_api(
tokenizer_manager, server_args.chat_template, server_args.model_path
)
else:
guess_chat_template_name_from_model_path(server_args.model_path)
if server_args.completion_template:
load_completion_template_for_openai_api(server_args.completion_template)
......
......@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
chat_template_exists,
generate_chat_conv,
generate_embedding_convs,
get_conv_template_by_model_path,
register_conv_template,
)
from sglang.srt.function_call_parser import FunctionCallParser
......@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
else:
chat_template_name = chat_template_arg
# Check chat-template
# TODO:
# 1. Do not import any code from sglang.lang
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
def guess_chat_template_name_from_model_path(model_path):
global chat_template_name
chat_template_name = get_conv_template_by_model_path(model_path)
if chat_template_name is not None:
logger.info(
f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
)
async def v1_files_create(
......
......@@ -47,11 +47,6 @@ class TestOpenAIVisionServer(CustomTestCase):
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"chatml-llava",
# "--log-requests",
],
)
cls.base_url += "/v1"
......@@ -475,8 +470,6 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"qwen2-vl",
"--mem-fraction-static",
"0.4",
],
......@@ -496,8 +489,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"qwen2-vl",
"--mem-fraction-static",
"0.4",
],
......@@ -517,8 +508,6 @@ class TestVLMContextLengthIssue(CustomTestCase):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"qwen2-vl",
"--context-length",
"300",
"--mem-fraction-static=0.80",
......@@ -573,10 +562,6 @@ class TestMllamaServer(TestOpenAIVisionServer):
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"llama_3_vision",
],
)
cls.base_url += "/v1"
......@@ -596,8 +581,6 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"minicpmv",
"--mem-fraction-static",
"0.4",
],
......@@ -617,8 +600,6 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"minicpmo",
"--mem-fraction-static",
"0.7",
],
......@@ -642,8 +623,6 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"deepseek-vl2",
"--context-length",
"4096",
],
......@@ -690,8 +669,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"janus-pro",
"--mem-fraction-static",
"0.4",
],
......@@ -744,8 +721,6 @@ class TestGemma3itServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"gemma-it",
"--mem-fraction-static",
"0.75",
"--enable-multimodal",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment