norm / vllm · Commits · 49b26e2c

feat: add ChatCompletion endpoint in OpenAI demo server. (#330)

Unverified commit 49b26e2c, authored Jul 03, 2023 by Ricardo Lu, committed by GitHub Jul 02, 2023.
Parent: dafd924c
Showing 2 changed files with 284 additions and 6 deletions:

vllm/entrypoints/openai/api_server.py   +236  -3
vllm/entrypoints/openai/protocol.py      +48  -3
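As context for the diff below: the commit adds a /v1/chat/completions route alongside the existing /v1/completions route. A minimal sketch of exercising it from Python (the host, port, and model name are illustrative assumptions, not part of this commit; per get_gen_prompt below, the model name must also resolve to a FastChat conversation template):

import requests

# Hypothetical call against a locally running demo server; the URL and
# model name are assumed values chosen for illustration.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "vicuna-7b-v1.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])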
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -4,7 +4,7 @@ import argparse
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union, Any
 
 import fastapi
 from fastapi import BackgroundTasks, Request
@@ -17,8 +17,12 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
-    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+    CompletionResponseStreamChoice, CompletionStreamResponse,
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
+from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -55,6 +59,70 @@ async def check_model(request) -> Optional[JSONResponse]:
     return ret
 
 
+async def get_gen_prompt(request) -> str:
+    conv = get_conv_template(request.model)
+    conv = Conversation(
+        name=conv.name,
+        system=conv.system,
+        roles=conv.roles,
+        messages=list(conv.messages),  # prevent in-place modification
+        offset=conv.offset,
+        sep_style=SeparatorStyle(conv.sep_style),
+        sep=conv.sep,
+        sep2=conv.sep2,
+        stop_str=conv.stop_str,
+        stop_token_ids=conv.stop_token_ids,
+    )
+
+    if isinstance(request.messages, str):
+        prompt = request.messages
+    else:
+        for message in request.messages:
+            msg_role = message["role"]
+            if msg_role == "system":
+                conv.system = message["content"]
+            elif msg_role == "user":
+                conv.append_message(conv.roles[0], message["content"])
+            elif msg_role == "assistant":
+                conv.append_message(conv.roles[1], message["content"])
+            else:
+                raise ValueError(f"Unknown role: {msg_role}")
+
+        # Add a blank message for the assistant.
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+    return prompt
+
+
+async def check_length(request, prompt, engine):
+    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
+        context_len = engine.engine.model_config.hf_config.max_sequence_length
+    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
+        context_len = engine.engine.model_config.hf_config.seq_length
+    elif hasattr(engine.engine.model_config.hf_config,
+                 "max_position_embeddings"):
+        context_len = engine.engine.model_config.hf_config.max_position_embeddings
+    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
+        context_len = engine.engine.model_config.hf_config.seq_length
+    else:
+        context_len = 2048
+
+    input_ids = tokenizer(prompt).input_ids
+    token_num = len(input_ids)
+
+    if token_num + request.max_tokens > context_len:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST,
+            f"This model's maximum context length is {context_len} tokens. "
+            f"However, you requested {request.max_tokens + token_num} tokens "
+            f"({token_num} in the messages, "
+            f"{request.max_tokens} in the completion). "
+            f"Please reduce the length of the messages or completion.",
+        )
+    else:
+        return None
+
+
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
@@ -85,6 +153,171 @@ def create_logprobs(token_ids: List[int],
     return logprobs
 
 
+@app.post("/v1/chat/completions")
+async def create_chat_completion(raw_request: Request):
+    """Completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/chat/create
+    for the API specification. This API mimics the OpenAI ChatCompletion API.
+
+    NOTE: Currently we do not support the following features:
+        - function_call (Users should implement this by themselves)
+        - logit_bias (to be supported by vLLM engine)
+    """
+    request = ChatCompletionRequest(**await raw_request.json())
+    logger.info(f"Received chat completion request: {request}")
+
+    error_check_ret = await check_model(request)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    if request.logit_bias is not None:
+        # TODO: support logit_bias in vLLM engine.
+        return create_error_response(HTTPStatus.BAD_REQUEST,
+                                     "logit_bias is not currently supported")
+
+    prompt = await get_gen_prompt(request)
+    error_check_ret = await check_length(request, prompt, engine)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    model_name = request.model
+    request_id = f"cmpl-{random_uuid()}"
+    created_time = int(time.time())
+    try:
+        sampling_params = SamplingParams(
+            n=request.n,
+            presence_penalty=request.presence_penalty,
+            frequency_penalty=request.frequency_penalty,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            stop=request.stop,
+            max_tokens=request.max_tokens,
+            best_of=request.best_of,
+            top_k=request.top_k,
+            ignore_eos=request.ignore_eos,
+            use_beam_search=request.use_beam_search,
+        )
+    except ValueError as e:
+        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
+
+    result_generator = engine.generate(prompt, sampling_params, request_id)
+
+    async def abort_request() -> None:
+        await engine.abort(request_id)
+
+    def create_stream_response_json(index: int,
+                                    text: str,
+                                    finish_reason: Optional[str] = None) -> str:
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=index,
+            delta=DeltaMessage(content=text),
+            finish_reason=finish_reason,
+        )
+        response = ChatCompletionStreamResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=[choice_data],
+        )
+        response_json = response.json(ensure_ascii=False)
+
+        return response_json
+
+    async def completion_stream_generator() -> AsyncGenerator[str, None]:
+        # First chunk with role
+        for i in range(request.n):
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=DeltaMessage(role="assistant"),
+                finish_reason=None,
+            )
+            chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        previous_texts = [""] * request.n
+        previous_num_tokens = [0] * request.n
+        async for res in result_generator:
+            res: RequestOutput
+            for output in res.outputs:
+                i = output.index
+                delta_text = output.text[len(previous_texts[i]):]
+                previous_texts[i] = output.text
+                previous_num_tokens[i] = len(output.token_ids)
+                response_json = create_stream_response_json(
+                    index=i,
+                    text=delta_text,
+                )
+                yield f"data: {response_json}\n\n"
+                if output.finish_reason is not None:
+                    response_json = create_stream_response_json(
+                        index=i,
+                        text="",
+                        finish_reason=output.finish_reason,
+                    )
+                    yield f"data: {response_json}\n\n"
+        yield "data: [DONE]\n\n"
+
+    # Streaming response
+    if request.stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type="text/event-stream",
+                                 background=background_tasks)
+
+    # Non-streaming response
+    final_res: RequestOutput = None
+    async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await abort_request()
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
+        final_res = res
+    assert final_res is not None
+    choices = []
+    for output in final_res.outputs:
+        choice_data = ChatCompletionResponseChoice(
+            index=output.index,
+            message=ChatMessage(role="assistant", content=output.text),
+            finish_reason=output.finish_reason,
+        )
+        choices.append(choice_data)
+
+    num_prompt_tokens = len(final_res.prompt_token_ids)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+    response = ChatCompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    if request.stream:
+        # When user requests streaming but we don't stream, we still need to
+        # return a streaming response with a single event.
+        response_json = response.json(ensure_ascii=False)
+
+        async def fake_stream_generator() -> AsyncGenerator[str, None]:
+            yield f"data: {response_json}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(fake_stream_generator(),
+                                 media_type="text/event-stream")
+
+    return response
+
+
 @app.post("/v1/completions")
 async def create_completion(raw_request: Request):
     """Completion API similar to OpenAI's API.
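The streaming branch above emits server-sent events: one data: {...} line per delta, a finish_reason event per choice, and a terminating data: [DONE]. A rough client-side sketch for consuming that stream (URL and payload are the same illustrative assumptions as earlier; str.removeprefix needs Python 3.9+):

import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "vicuna-7b-v1.1",  # assumed model/template name
        "messages": [{"role": "user", "content": "Tell me a joke."}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue  # blank separator lines between SSE events
        payload = line.decode("utf-8").removeprefix("data: ")
        if payload == "[DONE]":
            break
        delta = json.loads(payload)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)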
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -53,16 +53,22 @@ class UsageInfo(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Dict[str, str]]
+    messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
-    max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = 16
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
+    # Additional parameters supported by vLLM
+    best_of: Optional[int] = None
+    top_k: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    use_beam_search: Optional[bool] = False
 
 
 class CompletionRequest(BaseModel):
@@ -124,3 +130,42 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
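These are plain pydantic models (v1-style API, matching the response.json(...) calls in the server diff), so request parsing in create_chat_completion is just ChatCompletionRequest(**payload), and the defaults above (max_tokens=16, top_k=-1, stop=[]) fill in whatever the client omits. A small sketch of both directions, assuming the module is importable as in the diff:

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, DeltaMessage)

req = ChatCompletionRequest(
    model="vicuna-7b-v1.1",  # assumed model name
    messages=[{"role": "user", "content": "Hi!"}],
)
assert req.max_tokens == 16 and req.top_k == -1  # defaults from this commit

chunk = ChatCompletionStreamResponse(
    model=req.model,
    choices=[ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(content="Hello"))],
)
# exclude_unset mirrors how the server emits the first role-only chunk.
print(chunk.json(exclude_unset=True, ensure_ascii=False))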