Unverified Commit 49b26e2c authored by Ricardo Lu's avatar Ricardo Lu Committed by GitHub
Browse files

feat: add ChatCompletion endpoint in OpenAI demo server. (#330)

parent dafd924c
...@@ -4,7 +4,7 @@ import argparse ...@@ -4,7 +4,7 @@ import argparse
from http import HTTPStatus from http import HTTPStatus
import json import json
import time import time
from typing import AsyncGenerator, Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional, Union, Any
import fastapi import fastapi
from fastapi import BackgroundTasks, Request from fastapi import BackgroundTasks, Request
...@@ -17,8 +17,12 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -17,8 +17,12 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, CompletionResponseStreamChoice, CompletionStreamResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo) ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
ModelCard, ModelList, ModelPermission, UsageInfo)
from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -55,6 +59,70 @@ async def check_model(request) -> Optional[JSONResponse]: ...@@ -55,6 +59,70 @@ async def check_model(request) -> Optional[JSONResponse]:
return ret return ret
async def get_gen_prompt(request) -> str:
conv = get_conv_template(request.model)
conv = Conversation(
name=conv.name,
system=conv.system,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def check_length(request, prompt, engine):
if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
context_len = engine.engine.model_config.hf_config.max_sequence_length
elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
context_len = engine.engine.model_config.hf_config.seq_length
elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
context_len = engine.engine.model_config.hf_config.max_position_embeddings
elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
context_len = engine.engine.model_config.hf_config.seq_length
else:
context_len = 2048
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if token_num + request.max_tokens > context_len:
return create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {context_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return None
@app.get("/v1/models") @app.get("/v1/models")
async def show_available_models(): async def show_available_models():
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
...@@ -85,6 +153,171 @@ def create_logprobs(token_ids: List[int], ...@@ -85,6 +153,171 @@ def create_logprobs(token_ids: List[int],
return logprobs return logprobs
@app.post("/v1/chat/completions")
async def create_chat_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
request = ChatCompletionRequest(**await raw_request.json())
logger.info(f"Received chat completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
error_check_ret = await check_length(request, prompt, engine)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params,
request_id)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(index: int,
text: str,
finish_reason: Optional[str] = None) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
)
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=request_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
response_json = create_stream_response_json(
index=i,
text="",
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids)
for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
@app.post("/v1/completions") @app.post("/v1/completions")
async def create_completion(raw_request: Request): async def create_completion(raw_request: Request):
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
......
...@@ -53,16 +53,22 @@ class UsageInfo(BaseModel): ...@@ -53,16 +53,22 @@ class UsageInfo(BaseModel):
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[Dict[str, str]] messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7 temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0 top_p: Optional[float] = 1.0
n: Optional[int] = 1 n: Optional[int] = 1
max_tokens: Optional[int] = None max_tokens: Optional[int] = 16
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None user: Optional[str] = None
# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
...@@ -124,3 +130,42 @@ class CompletionStreamResponse(BaseModel): ...@@ -124,3 +130,42 @@ class CompletionStreamResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[CompletionResponseStreamChoice] choices: List[CompletionResponseStreamChoice]
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment