Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2c3ea294
Unverified
Commit
2c3ea294
authored
Apr 29, 2025
by
woodx
Committed by
GitHub
Apr 28, 2025
Browse files
[Feature] support auto chat template (#4949)
parent
5bb0accb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
112 additions
and
31 deletions
+112
-31
python/sglang/srt/conversation.py
python/sglang/srt/conversation.py
+97
-1
python/sglang/srt/entrypoints/engine.py
python/sglang/srt/entrypoints/engine.py
+6
-1
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+9
-4
test/srt/test_vision_openai_server.py
test/srt/test_vision_openai_server.py
+0
-25
No files found.
python/sglang/srt/conversation.py
View file @
2c3ea294
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import
dataclasses
import
dataclasses
from
enum
import
IntEnum
,
auto
from
enum
import
IntEnum
,
auto
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
sglang.srt.openai_api.protocol
import
ChatCompletionRequest
from
sglang.srt.openai_api.protocol
import
ChatCompletionRequest
...
@@ -407,6 +407,7 @@ class Conversation:
...
@@ -407,6 +407,7 @@ class Conversation:
# A global registry for all conversation templates
chat_templates: Dict[str, Conversation] = {}

# Matcher callables consulted by get_conv_template_by_model_path, in the order
# they were registered. Each one receives the model path and returns a
# conversation template name, or None when it does not recognise the path.
matching_function_registry: List[Callable[[str], Optional[str]]] = []
def
register_conv_template
(
template
:
Conversation
,
override
:
bool
=
False
):
def
register_conv_template
(
template
:
Conversation
,
override
:
bool
=
False
):
...
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
...
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
chat_templates
[
template
.
name
]
=
template
chat_templates
[
template
.
name
]
=
template
def register_conv_template_matching_function(func):
    """Register ``func`` as a model-path matcher and return it unchanged.

    The function is appended to ``matching_function_registry``, which
    ``get_conv_template_by_model_path`` walks in registration order.

    This is used as a decorator, so it must hand the function back:
    a decorator that returns ``None`` rebinds the decorated name to
    ``None`` at module level, making the matcher impossible to call or
    test directly.

    Args:
        func: Callable taking a model path and returning a template name,
            or ``None`` when the path is not recognised.

    Returns:
        ``func`` itself, unmodified.
    """
    matching_function_registry.append(func)
    return func
def get_conv_template_by_model_path(model_path):
    """Return the first template name a registered matcher yields, if any.

    Matchers in ``matching_function_registry`` are called lazily, in
    registration order, and evaluation stops at the first result that is
    not ``None``. ``None`` is returned when no matcher recognises the path.
    """
    candidates = (matcher(model_path) for matcher in matching_function_registry)
    return next((name for name in candidates if name is not None), None)
def chat_template_exists(template_name: str) -> bool:
    """Tell whether ``template_name`` is a key of the ``chat_templates`` registry."""
    return template_name in chat_templates
...
@@ -792,3 +805,86 @@ register_conv_template(
...
@@ -792,3 +805,86 @@ register_conv_template(
audio_token
=
"(<audio>./</audio>)"
,
audio_token
=
"(<audio>./</audio>)"
,
)
)
)
)
@register_conv_template_matching_function
def match_llama_3_vision(model_path: str):
    """Return ``"llama_3_vision"`` for Llama 3.2 vision model paths.

    The path is compared case-insensitively and must contain all of
    ``"llama"``, ``"3.2"`` and ``"vision"``.

    This matcher was originally named ``match_deepseek_janus_pro``, a
    copy-paste of the Janus-Pro matcher defined right after it, which then
    shadowed this definition. It is only reached through the matcher
    registry, so the rename does not affect callers.

    Returns:
        ``"llama_3_vision"`` on a match, otherwise ``None``.
    """
    model_path = model_path.lower()
    if all(tag in model_path for tag in ("llama", "3.2", "vision")):
        return "llama_3_vision"
    return None
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Return ``"janus-pro"`` when the lower-cased path contains ``"janus"``, else ``None``."""
    lowered = model_path.lower()
    return "janus-pro" if "janus" in lowered else None
@register_conv_template_matching_function
def match_vicuna(model_path: str):
    """Return ``"vicuna_v1.1"`` for paths that use the Vicuna v1.1 template.

    A path matches when its lower-cased form contains ``"vicuna"``,
    ``"llava-v1.5"`` or ``"llava-next-video-7b"``; otherwise ``None`` is
    returned.
    """
    lowered = model_path.lower()
    markers = ("vicuna", "llava-v1.5", "llava-next-video-7b")
    if any(marker in lowered for marker in markers):
        return "vicuna_v1.1"
    return None
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
    """Return ``"llama-2"`` for paths that use the Llama-2 chat template.

    Working on the lower-cased path, a match is either:
      * ``"llama-2"`` together with ``"chat"``, or
      * ``"instruct"`` together with one of ``"mistral"``, ``"mixtral"``
        or ``"codellama"``.

    Returns ``None`` when neither rule applies.
    """
    path = model_path.lower()
    is_llama2_chat = "llama-2" in path and "chat" in path
    is_instruct_family = "instruct" in path and any(
        family in path for family in ("mistral", "mixtral", "codellama")
    )
    if is_llama2_chat or is_instruct_family:
        return "llama-2"
    return None
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
    """Return ``"deepseek-vl2"`` when the lower-cased path contains both
    ``"deepseek"`` and ``"vl2"``; otherwise ``None``."""
    path = model_path.lower()
    if "deepseek" not in path or "vl2" not in path:
        return None
    return "deepseek-vl2"
@register_conv_template_matching_function
def match_chat_ml(model_path: str):
    """Pick a ChatML-family template from the lower-cased model path.

    Rules, checked in this order:
      * ``"qwen"`` and ``"vl"`` present: ``"gme-qwen2-vl"`` if ``"gme"`` is
        also present, else ``"qwen2-vl"``.
      * any of ``"llava-v1.6-34b"``, ``"llava-v1.6-yi-34b"``,
        ``"llava-next-video-34b"``, ``"llava-onevision-qwen2"`` present:
        ``"chatml-llava"``.

    Returns ``None`` when no rule applies.
    """
    path = model_path.lower()

    # Both Qwen-VL templates share the same precondition; "gme" picks between them.
    if "qwen" in path and "vl" in path:
        return "gme-qwen2-vl" if "gme" in path else "qwen2-vl"

    llava_markers = (
        "llava-v1.6-34b",
        "llava-v1.6-yi-34b",
        "llava-next-video-34b",
        "llava-onevision-qwen2",
    )
    if any(marker in path for marker in llava_markers):
        return "chatml-llava"
    return None
@register_conv_template_matching_function
def match_gemma_it(model_path: str):
    """Return ``"gemma-it"`` for Gemma paths that should use the chat template.

    Rules on the lower-cased path:
      * Paths without ``"gemma"`` never match.
      * ``"gemma-3"`` paths match unless they also contain ``"1b"``.
      * Other Gemma paths match when they contain ``"it"``.

    The Gemma-3 rule is checked first on purpose. Previously the generic
    ``"gemma"`` + ``"it"`` test ran first and already matched
    ``gemma-3-1b-it``, so the 1b exclusion below never took effect.

    Returns:
        ``"gemma-it"`` on a match, otherwise ``None``.
    """
    model_path = model_path.lower()
    if "gemma" not in model_path:
        return None
    if "gemma-3" in model_path:
        # gemma-3-1b-it is a completion model (per the original note), so it
        # must not receive the chat template.
        return None if "1b" in model_path else "gemma-it"
    # NOTE(review): this is a plain substring test, so any path containing
    # the letters "it" matches; kept as-is to avoid changing which existing
    # models are detected.
    if "it" in model_path:
        return "gemma-it"
    return None
@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """Map MiniCPM model paths to their template name.

    On the lower-cased path, ``"minicpm-v"`` yields ``"minicpmv"`` and is
    checked before ``"minicpm-o"``, which yields ``"minicpmo"``. Returns
    ``None`` when neither marker is present.
    """
    path = model_path.lower()
    for marker, template in (("minicpm-v", "minicpmv"), ("minicpm-o", "minicpmo")):
        if marker in path:
            return template
    return None
python/sglang/srt/entrypoints/engine.py
View file @
2c3ea294
...
@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
...
@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
)
)
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.openai_api.adapter
import
load_chat_template_for_openai_api
from
sglang.srt.openai_api.adapter
import
(
guess_chat_template_name_from_model_path
,
load_chat_template_for_openai_api
,
)
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.torch_memory_saver_adapter
import
TorchMemorySaverAdapter
from
sglang.srt.torch_memory_saver_adapter
import
TorchMemorySaverAdapter
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
...
@@ -584,6 +587,8 @@ def _launch_subprocesses(
...
@@ -584,6 +587,8 @@ def _launch_subprocesses(
load_chat_template_for_openai_api
(
load_chat_template_for_openai_api
(
tokenizer_manager
,
server_args
.
chat_template
,
server_args
.
model_path
tokenizer_manager
,
server_args
.
chat_template
,
server_args
.
model_path
)
)
else
:
guess_chat_template_name_from_model_path
(
server_args
.
model_path
)
if
server_args
.
completion_template
:
if
server_args
.
completion_template
:
load_completion_template_for_openai_api
(
server_args
.
completion_template
)
load_completion_template_for_openai_api
(
server_args
.
completion_template
)
...
...
python/sglang/srt/openai_api/adapter.py
View file @
2c3ea294
...
@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
...
@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
chat_template_exists
,
chat_template_exists
,
generate_chat_conv
,
generate_chat_conv
,
generate_embedding_convs
,
generate_embedding_convs
,
get_conv_template_by_model_path
,
register_conv_template
,
register_conv_template
,
)
)
from
sglang.srt.function_call_parser
import
FunctionCallParser
from
sglang.srt.function_call_parser
import
FunctionCallParser
...
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
...
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
else
:
else
:
chat_template_name
=
chat_template_arg
chat_template_name
=
chat_template_arg
# Check chat-template
# TODO:
def guess_chat_template_name_from_model_path(model_path):
    """Set the module-level ``chat_template_name`` by guessing from ``model_path``.

    The global is always overwritten with the result of
    ``get_conv_template_by_model_path`` (which may be ``None``). An info
    message is logged only when a template name was found.
    """
    global chat_template_name
    inferred = get_conv_template_by_model_path(model_path)
    chat_template_name = inferred
    if inferred is None:
        return
    logger.info(
        f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
    )
async
def
v1_files_create
(
async
def
v1_files_create
(
...
...
test/srt/test_vision_openai_server.py
View file @
2c3ea294
...
@@ -47,11 +47,6 @@ class TestOpenAIVisionServer(CustomTestCase):
...
@@ -47,11 +47,6 @@ class TestOpenAIVisionServer(CustomTestCase):
cls
.
base_url
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
api_key
=
cls
.
api_key
,
other_args
=
[
"--chat-template"
,
"chatml-llava"
,
# "--log-requests",
],
)
)
cls
.
base_url
+=
"/v1"
cls
.
base_url
+=
"/v1"
...
@@ -475,8 +470,6 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
...
@@ -475,8 +470,6 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
api_key
=
cls
.
api_key
,
other_args
=
[
other_args
=
[
"--chat-template"
,
"qwen2-vl"
,
"--mem-fraction-static"
,
"--mem-fraction-static"
,
"0.4"
,
"0.4"
,
],
],
...
@@ -496,8 +489,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
...
@@ -496,8 +489,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
api_key
=
cls
.
api_key
,
other_args
=
[
other_args
=
[
"--chat-template"
,
"qwen2-vl"
,
"--mem-fraction-static"
,
"--mem-fraction-static"
,
"0.4"
,
"0.4"
,
],
],
...
@@ -517,8 +508,6 @@ class TestVLMContextLengthIssue(CustomTestCase):
...
@@ -517,8 +508,6 @@ class TestVLMContextLengthIssue(CustomTestCase):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
api_key
=
cls
.
api_key
,
other_args
=
[
other_args
=
[
"--chat-template"
,
"qwen2-vl"
,
"--context-length"
,
"--context-length"
,
"300"
,
"300"
,
"--mem-fraction-static=0.80"
,
"--mem-fraction-static=0.80"
,
...
@@ -573,10 +562,6 @@ class TestMllamaServer(TestOpenAIVisionServer):
...
@@ -573,10 +562,6 @@ class TestMllamaServer(TestOpenAIVisionServer):
cls
.
base_url
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
api_key
=
cls
.
api_key
,
other_args
=
[
"--chat-template"
,
"llama_3_vision"
,
],
)
)
cls
.
base_url
+=
"/v1"
cls
.
base_url
+=
"/v1"
...
@@ -596,8 +581,6 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
...
@@ -596,8 +581,6 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
other_args
=
[
"--trust-remote-code"
,
"--trust-remote-code"
,
"--chat-template"
,
"minicpmv"
,
"--mem-fraction-static"
,
"--mem-fraction-static"
,
"0.4"
,
"0.4"
,
],
],
...
@@ -617,8 +600,6 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
...
@@ -617,8 +600,6 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
other_args
=
[
"--trust-remote-code"
,
"--trust-remote-code"
,
"--chat-template"
,
"minicpmo"
,
"--mem-fraction-static"
,
"--mem-fraction-static"
,
"0.7"
,
"0.7"
,
],
],
...
@@ -642,8 +623,6 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
...
@@ -642,8 +623,6 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
other_args
=
[
"--trust-remote-code"
,
"--trust-remote-code"
,
"--chat-template"
,
"deepseek-vl2"
,
"--context-length"
,
"--context-length"
,
"4096"
,
"4096"
,
],
],
...
@@ -690,8 +669,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
...
@@ -690,8 +669,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
other_args
=
[
"--trust-remote-code"
,
"--trust-remote-code"
,
"--chat-template"
,
"janus-pro"
,
"--mem-fraction-static"
,
"--mem-fraction-static"
,
"0.4"
,
"0.4"
,
],
],
...
@@ -744,8 +721,6 @@ class TestGemma3itServer(TestOpenAIVisionServer):
...
@@ -744,8 +721,6 @@ class TestGemma3itServer(TestOpenAIVisionServer):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
other_args
=
[
"--trust-remote-code"
,
"--trust-remote-code"
,
"--chat-template"
,
"gemma-it"
,
"--mem-fraction-static"
,
"--mem-fraction-static"
,
"0.75"
,
"0.75"
,
"--enable-multimodal"
,
"--enable-multimodal"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment