Unverified Commit 2c3ea294 authored by woodx, committed by GitHub

[Feature] support auto chat template (#4949)

parent 5bb0accb
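Summary of the change: when the OpenAI-compatible server is launched without `--chat-template`, it now tries to infer a registered conversation template from the model path. The diff below adds a registry of matching functions to `sglang/srt/conversation.py`, a `guess_chat_template_name_from_model_path` helper to the OpenAI adapter, wires the fallback into `_launch_subprocesses`, and drops the now-redundant `--chat-template` flags from the vision test suite. A minimal usage sketch of the new lookup (the functions are the ones added below; the model paths are only illustrative examples):

# Hedged sketch: resolving a chat template name from a model path with the
# helpers added in this diff. Matchers are tried in their definition order;
# the first one returning a non-None name wins.
from sglang.srt.conversation import get_conv_template_by_model_path

print(get_conv_template_by_model_path("Qwen/Qwen2-VL-7B-Instruct"))        # qwen2-vl
print(get_conv_template_by_model_path("meta-llama/Llama-3.2-11B-Vision"))  # llama_3_vision
print(get_conv_template_by_model_path("unknown-org/plain-text-model"))     # None (no guess)

In practice this means a launch such as `python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct` should no longer need an explicit `--chat-template qwen2-vl`, while a user-supplied `--chat-template` still takes precedence (see the `_launch_subprocesses` hunk further down).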
@@ -17,7 +17,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union

from sglang.srt.openai_api.protocol import ChatCompletionRequest
@@ -407,6 +407,7 @@ class Conversation:
# A global registry for all conversation templates
chat_templates: Dict[str, Conversation] = {}
matching_function_registry: List[Callable] = []

def register_conv_template(template: Conversation, override: bool = False):
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
    chat_templates[template.name] = template

def register_conv_template_matching_function(func):
    matching_function_registry.append(func)


def get_conv_template_by_model_path(model_path):
    for matching_func in matching_function_registry:
        conv_name = matching_func(model_path)
        if conv_name is not None:
            return conv_name
    return None

def chat_template_exists(template_name: str) -> bool:
    return template_name in chat_templates
@@ -792,3 +805,86 @@ register_conv_template(
        audio_token="(<audio>./</audio>)",
    )
)

@register_conv_template_matching_function
def match_llama_3_vision(model_path: str):
    if (
        "llama" in model_path.lower()
        and "3.2" in model_path.lower()
        and "vision" in model_path.lower()
    ):
        return "llama_3_vision"

@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    if "janus" in model_path.lower():
        return "janus-pro"

@register_conv_template_matching_function
def match_vicuna(model_path: str):
    if "vicuna" in model_path.lower():
        return "vicuna_v1.1"
    if "llava-v1.5" in model_path.lower():
        return "vicuna_v1.1"
    if "llava-next-video-7b" in model_path.lower():
        return "vicuna_v1.1"

@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
    model_path = model_path.lower()
    if "llama-2" in model_path and "chat" in model_path:
        return "llama-2"
    if (
        "mistral" in model_path or "mixtral" in model_path
    ) and "instruct" in model_path:
        return "llama-2"
    if "codellama" in model_path and "instruct" in model_path:
        return "llama-2"

@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
    model_path = model_path.lower()
    if "deepseek" in model_path and "vl2" in model_path:
        return "deepseek-vl2"

@register_conv_template_matching_function
def match_chat_ml(model_path: str):
    model_path = model_path.lower()
    # The suffix for Qwen2 chat models is now "instruct"
    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
        return "gme-qwen2-vl"
    if "qwen" in model_path and "vl" in model_path:
        return "qwen2-vl"
    if (
        "llava-v1.6-34b" in model_path
        or "llava-v1.6-yi-34b" in model_path
        or "llava-next-video-34b" in model_path
        or "llava-onevision-qwen2" in model_path
    ):
        return "chatml-llava"

@register_conv_template_matching_function
def match_gemma_it(model_path: str):
    model_path = model_path.lower()
    if "gemma" in model_path and "it" in model_path:
        return "gemma-it"
    if "gemma-3" in model_path and "1b" not in model_path:
        # gemma-3-1b-it is a completion model
        return "gemma-it"

@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
    model_path = model_path.lower()
    if "minicpm-v" in model_path:
        return "minicpmv"
    elif "minicpm-o" in model_path:
        return "minicpmo"
@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    guess_chat_template_name_from_model_path,
    load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
@@ -584,6 +587,8 @@ def _launch_subprocesses(
        load_chat_template_for_openai_api(
            tokenizer_manager, server_args.chat_template, server_args.model_path
        )
    else:
        guess_chat_template_name_from_model_path(server_args.model_path)

    if server_args.completion_template:
        load_completion_template_for_openai_api(server_args.completion_template)
...
@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
    chat_template_exists,
    generate_chat_conv,
    generate_embedding_convs,
    get_conv_template_by_model_path,
    register_conv_template,
)
from sglang.srt.function_call_parser import FunctionCallParser
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
    else:
        chat_template_name = chat_template_arg


# Check chat-template
# TODO:
# 1. Do not import any code from sglang.lang
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
def guess_chat_template_name_from_model_path(model_path):
    global chat_template_name
    chat_template_name = get_conv_template_by_model_path(model_path)
    if chat_template_name is not None:
        logger.info(
            f"Inferred the chat template name from the model path: {chat_template_name}."
        )

async def v1_files_create(
...
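With the template guessed at launch time, the OpenAI-compatible endpoint needs no template configuration on the client side either, which is why the test classes below drop their `--chat-template` arguments. A hedged end-to-end sketch (assumes a server is already running locally; the port, image URL, and `model="default"` choice are illustrative, following the convention used in sglang's examples):

# Hedged client-side sketch using the standard OpenAI SDK against a locally
# launched sglang server that auto-detected its chat template.
import openai

client = openai.OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.png"},
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)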
@@ -47,11 +47,6 @@ class TestOpenAIVisionServer(CustomTestCase):
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--chat-template",
                "chatml-llava",
                # "--log-requests",
            ],
        )
        cls.base_url += "/v1"
@@ -475,8 +470,6 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--chat-template",
                "qwen2-vl",
                "--mem-fraction-static",
                "0.4",
            ],
@@ -496,8 +489,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--chat-template",
                "qwen2-vl",
                "--mem-fraction-static",
                "0.4",
            ],
@@ -517,8 +508,6 @@ class TestVLMContextLengthIssue(CustomTestCase):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--chat-template",
                "qwen2-vl",
                "--context-length",
                "300",
                "--mem-fraction-static=0.80",
@@ -573,10 +562,6 @@ class TestMllamaServer(TestOpenAIVisionServer):
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--chat-template",
                "llama_3_vision",
            ],
        )
        cls.base_url += "/v1"
@@ -596,8 +581,6 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--chat-template",
                "minicpmv",
                "--mem-fraction-static",
                "0.4",
            ],
@@ -617,8 +600,6 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--chat-template",
                "minicpmo",
                "--mem-fraction-static",
                "0.7",
            ],
@@ -642,8 +623,6 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--chat-template",
                "deepseek-vl2",
                "--context-length",
                "4096",
            ],
@@ -690,8 +669,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--chat-template",
                "janus-pro",
                "--mem-fraction-static",
                "0.4",
            ],
@@ -744,8 +721,6 @@ class TestGemma3itServer(TestOpenAIVisionServer):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--chat-template",
                "gemma-it",
                "--mem-fraction-static",
                "0.75",
                "--enable-multimodal",
...