Unverified Commit 1d0fbe8e authored by Keith Stevens's avatar Keith Stevens Committed by GitHub
Browse files

[Feature] Adds basic support for image content in OpenAI chat routes (#113)

parent 97aa9b32
......@@ -357,7 +357,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
- Mistral
- Mixtral
- LLaVA
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- Qwen / Qwen 2
- AWQ quantization
......
......@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
from enum import IntEnum, auto
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from sglang.srt.managers.openai_protocol import ChatCompletionRequest
......@@ -52,6 +52,7 @@ class Conversation:
sep2: str = None
# Stop criteria (the default one is EOS token)
stop_str: Union[str, List[str]] = None
image_data: Optional[List[str]] = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
......@@ -251,6 +252,10 @@ class Conversation:
"""Append a new message."""
self.messages.append([role, message])
def append_image(self, image: str):
    """Append image data (a URL or base64-encoded string) to the conversation.

    Args:
        image: The image reference taken from an ``image_url`` content part.
    """
    # The dataclass declares ``image_data: Optional[List[str]] = None``;
    # lazily create the list so this method is safe even when the
    # conversation was built without an explicit ``image_data=[]``.
    if self.image_data is None:
        self.image_data = []
    self.image_data.append(image)
def update_last_message(self, message: str):
"""Update the last output.
......@@ -341,18 +346,31 @@ def generate_chat_conv(
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
image_data=[],
)
if isinstance(request.messages, str):
raise ValueError("The messages should be a list of dict.")
for message in request.messages:
msg_role = message["role"]
msg_role = message.role
if msg_role == "system":
conv.system_message = message["content"]
conv.system_message = message.content
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
# Handle the various types of Chat Request content types here.
role = conv.roles[0]
if isinstance(message.content, str):
conv.append_message(conv.roles[0], message.content)
else:
real_content = ""
for content in message.content:
if content.type == "text":
real_content += content.text
elif content.type == "image_url":
real_content += "<image>"
conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
conv.append_message(conv.roles[1], message.content)
else:
raise ValueError(f"Unknown role: {msg_role}")
......
import time
from typing import Dict, List, Optional, Union
from typing_extensions import Literal
from pydantic import BaseModel, Field
......@@ -68,9 +69,44 @@ class CompletionStreamResponse(BaseModel):
usage: UsageInfo
class ChatCompletionMessageGenericParam(BaseModel):
    """A chat message for the non-user roles; content is plain text only."""

    # Only "system" and "assistant"; user messages use
    # ChatCompletionMessageUserParam, which allows structured content.
    role: Literal["system", "assistant"]
    content: str
class ChatCompletionMessageContentTextPart(BaseModel):
    """A text fragment inside a structured user-message content list."""

    # Discriminator used when parsing the content-part union.
    type: Literal["text"]
    text: str
class ChatCompletionMessageContentImageURL(BaseModel):
    """Image reference carried by an ``image_url`` content part.

    ``url`` may be an HTTP(S) URL or a base64-encoded image string.
    """

    url: str
    # OpenAI-style detail hint; NOTE(review): nothing in the visible server
    # code acts on this value — confirm whether it is intentionally ignored.
    detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentImagePart(BaseModel):
    """An image fragment inside a structured user-message content list."""

    # Discriminator used when parsing the content-part union.
    type: Literal["image_url"]
    image_url: ChatCompletionMessageContentImageURL
# Union of the structured content-part types a user message may carry;
# the ``type`` literal on each model disambiguates during parsing.
ChatCompletionMessageContentPart = Union[
    ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
]
class ChatCompletionMessageUserParam(BaseModel):
    """A user chat message: either a plain string or a list of text/image parts."""

    role: Literal["user"]
    content: Union[str, List[ChatCompletionMessageContentPart]]
# Any message accepted in ChatCompletionRequest.messages; the ``role``
# literal on each model disambiguates user vs. system/assistant messages.
ChatCompletionMessageParam = Union[
    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
]
class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
messages: Union[str, List[ChatCompletionMessageParam]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
......
......@@ -150,12 +150,17 @@ class TokenizerManager:
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data is None:
pixel_values, image_hash, image_size = None, None, None
else:
if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data[0]
)
elif isinstance(obj.image_data, str):
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data
)
else:
pixel_values, image_hash, image_size = None, None, None
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=obj.text,
......
......@@ -16,7 +16,7 @@ import psutil
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.conversation import (
......@@ -190,16 +190,31 @@ async def v1_chat_completions(raw_request: Request):
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
if chat_template_name is None:
# This flow doesn't support the full OpenAI spec. Verify messages
# has the right type before proceeding:
for m in request.messages:
if not isinstance(m.content, str):
raise HTTPException(
status_code=503,
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
)
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
)
stop = request.stop
image_data = None
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
image_data = conv.image_data
stop = conv.stop_str or []
if request.stop:
if isinstance(request.stop, str):
......@@ -210,9 +225,11 @@ async def v1_chat_completions(raw_request: Request):
# Use the raw prompt and stop strings if the messages is already a string.
prompt = request.messages
stop = request.stop
image_data = None
adapted_request = GenerateReqInput(
text=prompt,
image_data=image_data,
sampling_params={
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
......@@ -303,6 +320,7 @@ def launch_server(server_args, pipe_finish_writer):
# Load chat template if needed
if server_args.chat_template is not None:
print(server_args.chat_template)
if not chat_template_exists(server_args.chat_template):
if not os.path.exists(server_args.chat_template):
raise RuntimeError(
......
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.openai_protocol import (
ChatCompletionMessageGenericParam,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentImageURL,
ChatCompletionMessageContentTextPart,
ChatCompletionMessageUserParam,
ChatCompletionRequest,
)
def test_chat_completion_to_conv_image():
    """Converting an image chat request should populate the conversation."""
    system_msg = ChatCompletionMessageGenericParam(
        role="system", content="You are a helpful AI assistant"
    )
    user_parts = [
        ChatCompletionMessageContentTextPart(type="text", text="Describe this image"),
        ChatCompletionMessageContentImagePart(
            type="image_url",
            image_url=ChatCompletionMessageContentImageURL(url="https://someurl.com"),
        ),
    ]
    user_msg = ChatCompletionMessageUserParam(role="user", content=user_parts)
    request = ChatCompletionRequest(model="default", messages=[system_msg, user_msg])

    conv = generate_chat_conv(request, "vicuna_v1.1")

    # The image part becomes an "<image>" placeholder appended to the text,
    # and the raw URL is collected separately in conv.image_data.
    expected_messages = [
        ["USER", "Describe this image<image>"],
        ["ASSISTANT", None],
    ]
    assert conv.messages == expected_messages
    assert conv.system_message == "You are a helpful AI assistant"
    assert conv.image_data == ["https://someurl.com"]
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_chat_completion_to_conv_image()
from sglang.srt.managers.openai_protocol import (
ChatCompletionMessageGenericParam,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentImageURL,
ChatCompletionMessageContentTextPart,
ChatCompletionMessageUserParam,
ChatCompletionRequest,
)
def test_chat_completion_request_image():
    """Chat completion requests carrying image parts should parse into models."""
    payload = {
        "model": "default",
        "messages": [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image"},
                    {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
                ],
            },
        ],
        "temperature": 0,
        "max_tokens": 64,
    }
    request = ChatCompletionRequest(**payload)

    expected_system = ChatCompletionMessageGenericParam(
        role="system", content="You are a helpful AI assistant"
    )
    expected_user = ChatCompletionMessageUserParam(
        role="user",
        content=[
            ChatCompletionMessageContentTextPart(
                type="text", text="Describe this image"
            ),
            ChatCompletionMessageContentImagePart(
                type="image_url",
                image_url=ChatCompletionMessageContentImageURL(
                    url="https://someurl.com"
                ),
            ),
        ],
    )

    assert len(request.messages) == 2
    assert request.messages[0] == expected_system
    assert request.messages[1] == expected_user
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_chat_completion_request_image()
......@@ -71,6 +71,36 @@ def test_chat_completion(args):
assert response.usage.total_tokens > 0
def test_chat_completion_image(args):
    """Smoke-test an image chat completion against a running server."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)

    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
        },
    }
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image"},
                image_part,
            ],
        },
    ]

    response = client.chat.completions.create(
        model="default",
        messages=messages,
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].message.content)

    # Sanity checks on the response metadata rather than on model output.
    assert response.id
    assert response.created
    usage = response.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
def test_chat_completion_stream(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.chat.completions.create(
......@@ -100,9 +130,14 @@ def test_chat_completion_stream(args):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
    parser.add_argument(
        "--test-image", action="store_true", help="Enables testing image inputs"
    )
    args = parser.parse_args()
    # Text-only tests always run against the server at --base-url.
    test_completion(args)
    test_completion_stream(args)
    test_chat_completion(args)
    test_chat_completion_stream(args)
    # Image test is opt-in via --test-image; presumably it requires a
    # vision-capable model (e.g. LLaVA) to be served — confirm before enabling.
    if args.test_image:
        test_chat_completion_image(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment