"docs/vscode:/vscode.git/clone" did not exist on "9eaace9216e10790c76e7675741daefa92ae1b59"
Unverified Commit 51c554d8 authored by Christopher Chou's avatar Christopher Chou Committed by GitHub
Browse files

Allow more flexible assistant and system response (#1256)

parent 79ece2c5
...@@ -386,7 +386,16 @@ def generate_chat_conv( ...@@ -386,7 +386,16 @@ def generate_chat_conv(
for message in request.messages: for message in request.messages:
msg_role = message.role msg_role = message.role
if msg_role == "system": if msg_role == "system":
conv.system_message = message.content if isinstance(message.content, str):
conv.system_message = message.content
elif isinstance(message.content, list):
if (
len(message.content) != 1
or getattr(message.content[0], "type", None) != "text"
):
raise ValueError("The system message should be a single text.")
else:
conv.system_message = getattr(message.content[0], "text", "")
elif msg_role == "user": elif msg_role == "user":
# Handle the various types of Chat Request content types here. # Handle the various types of Chat Request content types here.
role = conv.roles[0] role = conv.roles[0]
...@@ -414,7 +423,20 @@ def generate_chat_conv( ...@@ -414,7 +423,20 @@ def generate_chat_conv(
conv.append_image(content.image_url.url) conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content) conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant": elif msg_role == "assistant":
conv.append_message(conv.roles[1], message.content) parsed_content = ""
if isinstance(message.content, str):
parsed_content = message.content
elif isinstance(message.content, list):
if (
len(message.content) != 1
or getattr(message.content[0], "type", None) != "text"
):
raise ValueError(
"The assistant's response should be a single text."
)
else:
parsed_content = getattr(message.content[0], "text", "")
conv.append_message(conv.roles[1], parsed_content)
else: else:
raise ValueError(f"Unknown role: {msg_role}") raise ValueError(f"Unknown role: {msg_role}")
......
...@@ -844,8 +844,23 @@ def v1_chat_generate_request( ...@@ -844,8 +844,23 @@ def v1_chat_generate_request(
if not isinstance(request.messages, str): if not isinstance(request.messages, str):
# Apply chat template and its stop strings. # Apply chat template and its stop strings.
if chat_template_name is None: if chat_template_name is None:
openai_compatible_messages = []
for message in request.messages:
if isinstance(message.content, str):
openai_compatible_messages.append(
{"role": message.role, "content": message.content}
)
else:
content_list = message.dict()["content"]
for content in content_list:
if content["type"] == "text":
openai_compatible_messages.append(
{"role": message.role, "content": content["text"]}
)
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=True, add_generation_prompt=True openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
) )
stop = request.stop stop = request.stop
image_data = None image_data = None
......
...@@ -200,11 +200,6 @@ class CompletionStreamResponse(BaseModel): ...@@ -200,11 +200,6 @@ class CompletionStreamResponse(BaseModel):
usage: Optional[UsageInfo] = None usage: Optional[UsageInfo] = None
class ChatCompletionMessageGenericParam(BaseModel):
role: Literal["system", "assistant"]
content: str
class ChatCompletionMessageContentTextPart(BaseModel): class ChatCompletionMessageContentTextPart(BaseModel):
type: Literal["text"] type: Literal["text"]
text: str text: str
...@@ -225,6 +220,11 @@ ChatCompletionMessageContentPart = Union[ ...@@ -225,6 +220,11 @@ ChatCompletionMessageContentPart = Union[
] ]
class ChatCompletionMessageGenericParam(BaseModel):
    """A non-user chat message (a ``system`` or ``assistant`` turn).

    ``content`` accepts either a plain string or a list of text content
    parts.  Per the parsing logic introduced alongside this model, the
    list form must contain exactly one part of ``type == "text"``;
    anything else raises ``ValueError`` downstream.
    """

    role: Literal["system", "assistant"]
    content: Union[str, List[ChatCompletionMessageContentTextPart]]
class ChatCompletionMessageUserParam(BaseModel): class ChatCompletionMessageUserParam(BaseModel):
role: Literal["user"] role: Literal["user"]
content: Union[str, List[ChatCompletionMessageContentPart]] content: Union[str, List[ChatCompletionMessageContentPart]]
......
...@@ -76,6 +76,56 @@ class TestOpenAIVisionServer(unittest.TestCase): ...@@ -76,6 +76,56 @@ class TestOpenAIVisionServer(unittest.TestCase):
assert response.usage.completion_tokens > 0 assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0 assert response.usage.total_tokens > 0
def test_multi_turn_chat_completion(self):
    """Multi-turn vision chat against the running server.

    Exercises the newly-allowed flexible message shapes: the assistant
    turn is supplied as a list containing a single ``text`` content part
    (rather than a plain string), and the follow-up user turn is also a
    content-part list.  Verifies the server accepts the history and
    produces a coherent string reply.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    response = client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
                        },
                    },
                    {
                        "type": "text",
                        "text": "Describe this image in a very short sentence.",
                    },
                ],
            },
            {
                # Assistant history given as a single-element text-part
                # list — the shape this change newly permits.
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "There is a man at the back of a yellow cab ironing his clothes.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Repeat your previous answer."}
                ],
            },
        ],
        temperature=0,
    )
    assert response.choices[0].message.role == "assistant"
    text = response.choices[0].message.content
    assert isinstance(text, str)
    # Loose semantic check: the repeated answer should mention the
    # earlier description's subject.
    assert "man" in text or "cab" in text, text
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
def test_mult_images_chat_completion(self): def test_mult_images_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment