Unverified Commit 23471f9a authored by Cody Yu, committed by GitHub

Support v1/chat/completions (#50)

parent 61d4c939
@@ -248,6 +248,8 @@ In addition, the server supports an experimental OpenAI-compatible API.
import openai
client = openai.Client(
    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
# Text completion
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
@@ -255,6 +257,46 @@ response = client.completions.create(
    max_tokens=32,
)
print(response)
# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response)
```
In the above example, the server uses the chat template specified in the model tokenizer.
You can override the chat template if needed when launching the server:
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 \
    --chat-template llama-2
```
If the chat template you are looking for is missing, you are welcome to contribute it.
Meanwhile, you can also temporarily register your chat template as follows:
```json
{
"name": "my_model",
"system": "<|im_start|>system",
"user": "<|im_start|>user",
"assistant": "<|im_start|>assistant",
"sep_style": "CHATML",
"sep": "<|im_end|>",
"stop_str": ["<|im_end|>", "<|im_start|>"]
}
```
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 \
    --chat-template ./my_model_template.json
```
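The chat endpoint also supports streaming. Below is a minimal sketch of streaming usage with the same client (this assumes the server launched above is running; the loop mirrors what the bundled test script does with `stream=True`):
```
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0,
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    # Each chunk carries an incremental delta; the first chunk holds the assistant role.
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
print()
```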
### Additional Arguments
......
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
from sglang.srt.managers.openai_protocol import ChatCompletionRequest
from enum import IntEnum, auto
import dataclasses
from typing import Dict, List, Tuple, Union
class SeparatorStyle(IntEnum):
"""Separator styles."""
ADD_COLON_SINGLE = auto()
ADD_COLON_TWO = auto()
ADD_COLON_SPACE_SINGLE = auto()
NO_COLON_SINGLE = auto()
NO_COLON_TWO = auto()
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
DOLLY = auto()
RWKV = auto()
PHOENIX = auto()
ROBIN = auto()
FALCON_CHAT = auto()
CHATGLM3 = auto()
DEEPSEEK_CHAT = auto()
METAMATH = auto()
@dataclasses.dataclass
class Conversation:
"""A class that manages prompt templates and keeps all conversation history."""
# The name of this template
name: str
# The template of the system prompt
system_template: str = "{system_message}"
# The system message
system_message: str = ""
# The names of two roles
roles: Tuple[str] = ("USER", "ASSISTANT")
# All messages. Each item is (role, message).
messages: List[List[str]] = ()
# The number of few shot examples
offset: int = 0
# The separator style and configurations
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
sep: str = "\n"
sep2: str = None
# Stop criteria (the default one is EOS token)
stop_str: Union[str, List[str]] = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
ret = system_prompt + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
ret = system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
ret = system_prompt + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ": " # must end with a space
return ret
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
ret = "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
ret += role + "\n" + message + self.sep
else:
ret += role + "\n"
return ret
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
ret = system_prompt
for role, message in self.messages:
if message:
ret += role + message + self.sep
else:
ret += role
return ret
elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + message + seps[i % 2]
else:
ret += role
return ret
elif self.sep_style == SeparatorStyle.RWKV:
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message.replace("\r\n", "\n").replace("\n\n", "\n")
ret += "\n\n"
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
ret = system_prompt
else:
ret = "[INST] "
for i, (role, message) in enumerate(self.messages):
tag = self.roles[i % 2]
if message:
if i == 0:
ret += message + " "
else:
ret += tag + " " + message + seps[i % 2]
else:
ret += tag
return ret
elif self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
ret = system_prompt + self.sep
else:
ret = ""
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
ret += f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
ret += f"{role}{message}{self.sep}"
else:
ret += f"{role}:"
return ret
elif self.sep_style == SeparatorStyle.CHATML:
ret = "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
ret += role + "\n" + message + self.sep + "\n"
else:
ret += role + "\n"
return ret
elif self.sep_style == SeparatorStyle.CHATGLM3:
ret = ""
if self.system_message:
ret += system_prompt
for role, message in self.messages:
if message:
ret += role + "\n" + message
else:
ret += role
return ret
elif self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
ret += "<s>"
if message:
ret += role + ":" + message + seps[i % 2] + "\n"
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ":\n" + message + seps[i % 2]
if i % 2 == 1:
ret += "\n\n"
else:
ret += role + ":\n"
return ret
elif self.sep_style == SeparatorStyle.PHOENIX:
ret = system_prompt
for role, message in self.messages:
if message:
ret += role + ": " + "<s>" + message + "</s>"
else:
ret += role + ": " + "<s>"
return ret
elif self.sep_style == SeparatorStyle.ROBIN:
ret = system_prompt + self.sep
for role, message in self.messages:
if message:
ret += role + ":\n" + message + self.sep
else:
ret += role + ":\n"
return ret
elif self.sep_style == SeparatorStyle.FALCON_CHAT:
ret = ""
if self.system_message:
ret += system_prompt + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.METAMATH:
ret = "" if system_prompt == "" else system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
# For MetaMath, sep2 is used to prefix the message.
starting_sep = ":\n" if i % 2 == 0 else ": " + self.sep2
ending_sep = self.sep if i % 2 == 0 else ""
if message:
ret += role + starting_sep + message + ending_sep
else:
ret += role + starting_sep
return ret
elif self.sep_style == SeparatorStyle.DEEPSEEK_CHAT:
seps = [self.sep, self.sep2]
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def set_system_message(self, system_message: str):
"""Set the system message."""
self.system_message = system_message
def append_message(self, role: str, message: str):
"""Append a new message."""
self.messages.append([role, message])
def update_last_message(self, message: str):
"""Update the last output.
The last message is typically set to be None when constructing the prompt,
so we need to update it in-place after getting the response from a model.
"""
self.messages[-1][1] = message
def to_gradio_chatbot(self):
"""Convert the conversation to gradio chatbot format."""
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
if self.system_message == "":
ret = []
else:
ret = [{"role": "system", "content": self.system_message}]
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append({"role": "user", "content": msg})
else:
if msg is not None:
ret.append({"role": "assistant", "content": msg})
return ret
def copy(self):
return Conversation(
name=self.name,
system_template=self.system_template,
system_message=self.system_message,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
stop_str=self.stop_str,
)
def dict(self):
return {
"template_name": self.name,
"system_message": self.system_message,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
}
# A global registry for all conversation templates
chat_templates: Dict[str, Conversation] = {}
def register_conv_template(template: Conversation, override: bool = False):
"""Register a new conversation template."""
if not override:
assert template.name not in chat_templates, f"{template.name} has been registered."
chat_templates[template.name] = template
def chat_template_exists(template_name: str) -> bool:
return template_name in chat_templates
def generate_chat_conv(request: ChatCompletionRequest, template_name: str) -> Conversation:
conv = chat_templates[template_name].copy()
conv = Conversation(
name=conv.name,
system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
)
if isinstance(request.messages, str):
raise ValueError("The messages should be a list of dicts.")
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
return conv
# llama2 template
# reference: https://huggingface.co/blog/codellama#conversational-instructions
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
register_conv_template(
Conversation(
name="llama-2",
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_str=["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"],
)
)
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message="You are an AI assistant.",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
stop_str=["<|endoftext|>", "<|im_end|>"],
)
)
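# Example (illustrative, not part of the module): how the built-in "chatml"
# template registered above renders a prompt.
#
#   conv = chat_templates["chatml"].copy()
#   conv.append_message(conv.roles[0], "What is the capital of France?")
#   conv.append_message(conv.roles[1], None)
#   print(conv.get_prompt())
#   # <|im_start|>system
#   # You are an AI assistant.<|im_end|>
#   # <|im_start|>user
#   # What is the capital of France?<|im_end|>
#   # <|im_start|>assistant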
@@ -65,3 +65,59 @@ class CompletionStreamResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: UsageInfo
class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = 16
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
best_of: Optional[int] = None
class ChatMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[str] = None
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[str] = None
class ChatCompletionStreamResponse(BaseModel):
id: str
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
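# Note (illustrative): the server below serializes each streamed chunk with
# chunk.json(exclude_unset=True), so fields left at their defaults (e.g. object,
# created) are omitted. A content delta then looks roughly like:
#   {"id": "...", "model": "default",
#    "choices": [{"index": 0, "delta": {"content": "Paris"}, "finish_reason": null}]}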
@@ -2,6 +2,7 @@
import asyncio
import json
import multiprocessing as mp
import os
import sys
import threading
import time
@@ -17,15 +18,29 @@ import uvloop
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.conversation import (
Conversation,
SeparatorStyle,
chat_template_exists,
generate_chat_conv,
register_conv_template,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
UsageInfo,
)
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -37,6 +52,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
app = FastAPI()
tokenizer_manager = None
chat_template_name = None
@app.get("/get_model_info") @app.get("/get_model_info")
...@@ -46,6 +62,7 @@ async def get_model_info(): ...@@ -46,6 +62,7 @@ async def get_model_info():
} }
return result return result
async def stream_generator(obj): async def stream_generator(obj):
async for out in tokenizer_manager.generate_request(obj): async for out in tokenizer_manager.generate_request(obj):
yield out yield out
...@@ -61,7 +78,7 @@ async def generate_request(obj: GenerateReqInput): ...@@ -61,7 +78,7 @@ async def generate_request(obj: GenerateReqInput):
async for out in stream_generator(obj): async for out in stream_generator(obj):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream") return StreamingResponse(stream_results(), media_type="text/event-stream")
ret = await tokenizer_manager.generate_request(obj).__anext__() ret = await tokenizer_manager.generate_request(obj).__anext__()
...@@ -91,11 +108,15 @@ async def v1_completions(raw_request: Request): ...@@ -91,11 +108,15 @@ async def v1_completions(raw_request: Request):
adapted_request.post_init()
if adapted_request.stream:
async def generate_stream_resp():
stream_buffer = ""
async for content in stream_generator(adapted_request):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
delta = text[len(stream_buffer) :]
stream_buffer = text
choice_data = CompletionResponseStreamChoice(
index=0,
@@ -108,12 +129,17 @@ async def v1_completions(raw_request: Request):
object="text_completion",
choices=[choice_data],
model=request.model,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
# Non-streaming response.
ret = await generate_request(adapted_request)
@@ -121,7 +147,7 @@ async def v1_completions(raw_request: Request):
index=0,
text=ret["text"],
logprobs=None,
finish_reason=None, # TODO(comaniac): Add finish reason.
)
prompt_tokens = ret["meta_info"]["prompt_tokens"]
@@ -139,8 +165,108 @@ async def v1_completions(raw_request: Request):
return response
@app.post("/v1/chat/completions")
async def v1_chat_completions(raw_request: Request):
request_json = await raw_request.json()
request = ChatCompletionRequest(**request_json)
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
if chat_template_name is None:
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
)
stop = request.stop
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
stop = conv.stop_str or []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
else:
stop.extend(request.stop)
else:
# Use the raw prompt and stop strings if messages is already a string.
prompt = request.messages
stop = request.stop
adapted_request = GenerateReqInput(
text=prompt,
sampling_params={
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
stream=request.stream,
)
adapted_request.post_init()
if adapted_request.stream:
async def generate_stream_resp():
is_first = True
stream_buffer = ""
async for content in stream_generator(adapted_request):
if is_first:
# First chunk with role
is_first = False
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"], choices=[choice_data], model=request.model
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
text = content["text"]
delta = text[len(stream_buffer) :]
stream_buffer = text
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(content=delta), finish_reason=None
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"], choices=[choice_data], model=request.model
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
# Non-streaming response.
ret = await generate_request(adapted_request)
prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=ret["text"]),
finish_reason=None, # TODO(comaniac): Add finish reason.
)
response = ChatCompletionResponse(
id=ret["meta_info"]["id"],
model=request.model,
choices=[choice_data],
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
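# Example request against this endpoint (illustrative), assuming the server runs
# on the default port 30000:
#   curl http://127.0.0.1:30000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "default", "messages": [{"role": "user", "content": "Hello"}]}'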
def launch_server(server_args, pipe_finish_writer):
global tokenizer_manager
global chat_template_name
# Allocate ports
can_use_ports = alloc_usable_network_port(
@@ -154,6 +280,36 @@ def launch_server(server_args, pipe_finish_writer):
model_rpc_ports=can_use_ports[4:],
)
# Load chat template if needed
if server_args.chat_template is not None:
if not chat_template_exists(server_args.chat_template):
if not os.path.exists(server_args.chat_template):
raise RuntimeError(
f"Chat template {server_args.chat_template} is not a built-in template name "
"or a valid chat template file path."
)
with open(server_args.chat_template, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(f"Unknown separator style: {template['sep_style']}") from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
chat_template_name = template["name"]
else:
chat_template_name = server_args.chat_template
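# Example template file (illustrative), matching the keys read above; the README
# documents the same JSON shape:
#   {"name": "my_model", "system": "<|im_start|>system",
#    "user": "<|im_start|>user", "assistant": "<|im_start|>assistant",
#    "sep_style": "CHATML", "sep": "<|im_end|>",
#    "stop_str": ["<|im_end|>", "<|im_start|>"]}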
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args)
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
......
@@ -11,6 +11,7 @@ class ServerArgs:
port: int = 30000
load_format: str = "auto"
tokenizer_mode: str = "auto"
chat_template: Optional[str] = None
trust_remote_code: bool = True
mem_fraction_static: Optional[float] = None
tp_size: int = 1
@@ -77,6 +78,12 @@ class ServerArgs:
"tokenizer if available, and 'slow' will "
"always use the slow tokenizer.",
)
parser.add_argument(
"--chat-template",
type=str,
default=ServerArgs.chat_template,
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
......
""" """
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 First run the following command to launch the server.
Note that TinyLlama adopts different chat templates in different versions.
For v0.4, the chat template is chatml.
Output: python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo --port 30000 --chat-template chatml
Output example:
The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
""" """
import argparse
@@ -38,13 +46,57 @@ def test_completion_stream(args):
for r in response:
print(r.choices[0].text, end="", flush=True)
assert r.id
assert r.created
assert r.usage.prompt_tokens > 0
assert r.usage.completion_tokens > 0
assert r.usage.total_tokens > 0
print()
def test_chat_completion(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "What is the capital of France?"},
],
temperature=0,
max_tokens=32,
)
print(response.choices[0].message.content)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def test_chat_completion_stream(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
stream=True,
)
is_first = True
for chunk in response:
if is_first:
is_first = False
assert chunk.choices[0].delta.role == "assistant"
continue
data = chunk.choices[0].delta
if not data.content:
continue
print(data.content, end="", flush=True)
print()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
@@ -52,3 +104,5 @@ if __name__ == "__main__":
test_completion(args)
test_completion_stream(args)
test_chat_completion(args)
test_chat_completion_stream(args)