Unverified Commit 1d0fbe8e authored by Keith Stevens's avatar Keith Stevens Committed by GitHub
Browse files

[Feature] Adds basic support for image content in OpenAI chat routes (#113)

parent 97aa9b32
...@@ -357,7 +357,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port ...@@ -357,7 +357,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
- Mistral - Mistral
- Mixtral - Mixtral
- LLaVA - LLaVA
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000` - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- Qwen / Qwen 2 - Qwen / Qwen 2
- AWQ quantization - AWQ quantization
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses import dataclasses
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Dict, List, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from sglang.srt.managers.openai_protocol import ChatCompletionRequest from sglang.srt.managers.openai_protocol import ChatCompletionRequest
...@@ -52,6 +52,7 @@ class Conversation: ...@@ -52,6 +52,7 @@ class Conversation:
sep2: str = None sep2: str = None
# Stop criteria (the default one is EOS token) # Stop criteria (the default one is EOS token)
stop_str: Union[str, List[str]] = None stop_str: Union[str, List[str]] = None
image_data: Optional[List[str]] = None
def get_prompt(self) -> str: def get_prompt(self) -> str:
"""Get the prompt for generation.""" """Get the prompt for generation."""
...@@ -251,6 +252,10 @@ class Conversation: ...@@ -251,6 +252,10 @@ class Conversation:
"""Append a new message.""" """Append a new message."""
self.messages.append([role, message]) self.messages.append([role, message])
def append_image(self, image: str):
    """Append a new image to the conversation's image data.

    Args:
        image: Image reference to record (a URL or base64 string, as
            received from the OpenAI-style `image_url` content part).
    """
    # The dataclass default for image_data is None; lazily initialize so
    # callers that did not pass image_data=[] at construction still work
    # (the original raised AttributeError in that case).
    if self.image_data is None:
        self.image_data = []
    self.image_data.append(image)
def update_last_message(self, message: str): def update_last_message(self, message: str):
"""Update the last output. """Update the last output.
...@@ -341,18 +346,31 @@ def generate_chat_conv( ...@@ -341,18 +346,31 @@ def generate_chat_conv(
sep=conv.sep, sep=conv.sep,
sep2=conv.sep2, sep2=conv.sep2,
stop_str=conv.stop_str, stop_str=conv.stop_str,
image_data=[],
) )
if isinstance(request.messages, str): if isinstance(request.messages, str):
raise ValueError("The messages should be a list of dict.") raise ValueError("The messages should be a list of dict.")
for message in request.messages: for message in request.messages:
msg_role = message["role"] msg_role = message.role
if msg_role == "system": if msg_role == "system":
conv.system_message = message["content"] conv.system_message = message.content
elif msg_role == "user": elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"]) # Handle the various types of Chat Request content types here.
role = conv.roles[0]
if isinstance(message.content, str):
conv.append_message(conv.roles[0], message.content)
else:
real_content = ""
for content in message.content:
if content.type == "text":
real_content += content.text
elif content.type == "image_url":
real_content += "<image>"
conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant": elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"]) conv.append_message(conv.roles[1], message.content)
else: else:
raise ValueError(f"Unknown role: {msg_role}") raise ValueError(f"Unknown role: {msg_role}")
......
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from typing_extensions import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
...@@ -68,9 +69,44 @@ class CompletionStreamResponse(BaseModel): ...@@ -68,9 +69,44 @@ class CompletionStreamResponse(BaseModel):
usage: UsageInfo usage: UsageInfo
class ChatCompletionMessageGenericParam(BaseModel):
    """A chat message whose content is a plain string (system/assistant roles)."""

    # Only "system" and "assistant" here; "user" messages are modeled separately
    # because they may carry structured (text + image) content parts.
    role: Literal["system", "assistant"]
    content: str
class ChatCompletionMessageContentTextPart(BaseModel):
    """A text segment inside a structured user-message content list."""

    # Discriminator for the content-part union; always the literal "text".
    type: Literal["text"]
    text: str
class ChatCompletionMessageContentImageURL(BaseModel):
    """The image reference carried by an image_url content part."""

    # URL or base64 string identifying the image.
    url: str
    # OpenAI-style fidelity hint; defaults to "auto" as in the OpenAI API.
    detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentImagePart(BaseModel):
    """An image segment inside a structured user-message content list."""

    # Discriminator for the content-part union; always the literal "image_url".
    type: Literal["image_url"]
    image_url: ChatCompletionMessageContentImageURL
# A single entry of a structured user-message content list: either a text
# part or an image_url part, distinguished by their `type` literal field.
ChatCompletionMessageContentPart = Union[
    ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
]
class ChatCompletionMessageUserParam(BaseModel):
    """A user chat message; content may be a plain string or a list of parts."""

    role: Literal["user"]
    # Either a bare string, or a list of text/image parts per the OpenAI
    # vision-style chat format.
    content: Union[str, List[ChatCompletionMessageContentPart]]
# Any message accepted in ChatCompletionRequest.messages: a string-content
# system/assistant message, or a user message with string/structured content.
ChatCompletionMessageParam = Union[
    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
]
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: Union[str, List[Dict[str, str]]] messages: Union[str, List[ChatCompletionMessageParam]]
temperature: Optional[float] = 0.7 temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0 top_p: Optional[float] = 1.0
n: Optional[int] = 1 n: Optional[int] = 1
......
...@@ -150,12 +150,17 @@ class TokenizerManager: ...@@ -150,12 +150,17 @@ class TokenizerManager:
if sampling_params.max_new_tokens != 0: if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer) sampling_params.normalize(self.tokenizer)
sampling_params.verify() sampling_params.verify()
if obj.image_data is None:
pixel_values, image_hash, image_size = None, None, None if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
else: pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data[0]
)
elif isinstance(obj.image_data, str):
pixel_values, image_hash, image_size = await self.get_pixel_values( pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data obj.image_data
) )
else:
pixel_values, image_hash, image_size = None, None, None
tokenized_obj = TokenizedGenerateReqInput( tokenized_obj = TokenizedGenerateReqInput(
rid=rid, rid=rid,
input_text=obj.text, input_text=obj.text,
......
...@@ -16,7 +16,7 @@ import psutil ...@@ -16,7 +16,7 @@ import psutil
import requests import requests
import uvicorn import uvicorn
import uvloop import uvloop
from fastapi import FastAPI, Request from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.conversation import ( from sglang.srt.conversation import (
...@@ -190,16 +190,31 @@ async def v1_chat_completions(raw_request: Request): ...@@ -190,16 +190,31 @@ async def v1_chat_completions(raw_request: Request):
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid. # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1 assert request.n == 1
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str): if not isinstance(request.messages, str):
# Apply chat template and its stop strings. # Apply chat template and its stop strings.
if chat_template_name is None: if chat_template_name is None:
# This flow doesn't support the full OpenAI spec. Verify messages
# has the right type before proceeding:
for m in request.messages:
if not isinstance(m.content, str):
raise HTTPException(
status_code=503,
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
)
prompt = tokenizer_manager.tokenizer.apply_chat_template( prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True request.messages, tokenize=False, add_generation_prompt=True
) )
stop = request.stop stop = request.stop
image_data = None
else: else:
conv = generate_chat_conv(request, chat_template_name) conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt() prompt = conv.get_prompt()
image_data = conv.image_data
stop = conv.stop_str or [] stop = conv.stop_str or []
if request.stop: if request.stop:
if isinstance(request.stop, str): if isinstance(request.stop, str):
...@@ -210,9 +225,11 @@ async def v1_chat_completions(raw_request: Request): ...@@ -210,9 +225,11 @@ async def v1_chat_completions(raw_request: Request):
# Use the raw prompt and stop strings if the messages is already a string. # Use the raw prompt and stop strings if the messages is already a string.
prompt = request.messages prompt = request.messages
stop = request.stop stop = request.stop
image_data = None
adapted_request = GenerateReqInput( adapted_request = GenerateReqInput(
text=prompt, text=prompt,
image_data=image_data,
sampling_params={ sampling_params={
"temperature": request.temperature, "temperature": request.temperature,
"max_new_tokens": request.max_tokens, "max_new_tokens": request.max_tokens,
...@@ -303,6 +320,7 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -303,6 +320,7 @@ def launch_server(server_args, pipe_finish_writer):
# Load chat template if needed # Load chat template if needed
if server_args.chat_template is not None: if server_args.chat_template is not None:
print(server_args.chat_template)
if not chat_template_exists(server_args.chat_template): if not chat_template_exists(server_args.chat_template):
if not os.path.exists(server_args.chat_template): if not os.path.exists(server_args.chat_template):
raise RuntimeError( raise RuntimeError(
......
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.openai_protocol import (
ChatCompletionMessageGenericParam,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentImageURL,
ChatCompletionMessageContentTextPart,
ChatCompletionMessageUserParam,
ChatCompletionRequest,
)
def test_chat_completion_to_conv_image():
    """Ensure an image chat request converts into a conversation correctly."""
    # Build the structured user content: a text part followed by an image part.
    image_url = ChatCompletionMessageContentImageURL(url="https://someurl.com")
    user_parts = [
        ChatCompletionMessageContentTextPart(type="text", text="Describe this image"),
        ChatCompletionMessageContentImagePart(type="image_url", image_url=image_url),
    ]
    system_msg = ChatCompletionMessageGenericParam(
        role="system", content="You are a helpful AI assistant"
    )
    user_msg = ChatCompletionMessageUserParam(role="user", content=user_parts)
    request = ChatCompletionRequest(model="default", messages=[system_msg, user_msg])

    conv = generate_chat_conv(request, "vicuna_v1.1")

    # The image part becomes an <image> placeholder in the prompt text, while
    # the raw URL is collected separately on the conversation.
    expected_messages = [
        ["USER", "Describe this image<image>"],
        ["ASSISTANT", None],
    ]
    assert conv.messages == expected_messages
    assert conv.system_message == "You are a helpful AI assistant"
    assert conv.image_data == ["https://someurl.com"]


if __name__ == "__main__":
    test_chat_completion_to_conv_image()
from sglang.srt.managers.openai_protocol import (
ChatCompletionMessageGenericParam,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentImageURL,
ChatCompletionMessageContentTextPart,
ChatCompletionMessageUserParam,
ChatCompletionRequest,
)
def test_chat_completion_request_image():
    """Test that Chat Completion Requests with images can be converted."""
    # Raw OpenAI-style payload with mixed text/image user content.
    payload = {
        "model": "default",
        "messages": [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image"},
                    {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
                ],
            },
        ],
        "temperature": 0,
        "max_tokens": 64,
    }
    request = ChatCompletionRequest(**payload)

    assert len(request.messages) == 2

    # The system entry parses into the generic (string-content) message model.
    expected_system = ChatCompletionMessageGenericParam(
        role="system", content="You are a helpful AI assistant"
    )
    # The user entry parses into structured parts: text first, then image.
    expected_user = ChatCompletionMessageUserParam(
        role="user",
        content=[
            ChatCompletionMessageContentTextPart(type="text", text="Describe this image"),
            ChatCompletionMessageContentImagePart(
                type="image_url",
                image_url=ChatCompletionMessageContentImageURL(url="https://someurl.com"),
            ),
        ],
    )
    assert request.messages[0] == expected_system
    assert request.messages[1] == expected_user


if __name__ == "__main__":
    test_chat_completion_request_image()
...@@ -71,6 +71,36 @@ def test_chat_completion(args): ...@@ -71,6 +71,36 @@ def test_chat_completion(args):
assert response.usage.total_tokens > 0 assert response.usage.total_tokens > 0
def test_chat_completion_image(args):
    """Smoke-test the OpenAI chat route with an image_url content part."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)

    # Structured user content: a text instruction plus one hosted image.
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
        },
    }
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {
            "role": "user",
            "content": [{"type": "text", "text": "Describe this image"}, image_part],
        },
    ]

    response = client.chat.completions.create(
        model="default",
        messages=messages,
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].message.content)

    # Basic sanity checks on the response envelope and token accounting.
    assert response.id
    assert response.created
    usage = response.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
def test_chat_completion_stream(args): def test_chat_completion_stream(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url) client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.chat.completions.create( response = client.chat.completions.create(
...@@ -100,9 +130,14 @@ def test_chat_completion_stream(args): ...@@ -100,9 +130,14 @@ def test_chat_completion_stream(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1") parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
parser.add_argument(
"--test-image", action="store_true", help="Enables testing image inputs"
)
args = parser.parse_args() args = parser.parse_args()
test_completion(args) test_completion(args)
test_completion_stream(args) test_completion_stream(args)
test_chat_completion(args) test_chat_completion(args)
test_chat_completion_stream(args) test_chat_completion_stream(args)
if args.test_image:
test_chat_completion_image(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment