SIYIXNI / vllm / Commits / 3a7dd7e3

Unverified commit 3a7dd7e3, authored Jan 24, 2024 by Simon Mo, committed by GitHub on Jan 24, 2024.
Support Batch Completion in Server (#2529)
parent 223c1922

Showing 2 changed files with 215 additions and 105 deletions.
tests/entrypoints/test_openai_server.py: +53, -2
vllm/entrypoints/openai/serving_completion.py: +162, -103
tests/entrypoints/test_openai_server.py (view file @ 3a7dd7e3)

-import time
+import os
 import subprocess
+import time
 import sys
 import pytest
...
@@ -17,8 +18,11 @@ pytestmark = pytest.mark.asyncio
 class ServerRunner:

     def __init__(self, args):
+        env = os.environ.copy()
+        env["PYTHONUNBUFFERED"] = "1"
         self.proc = subprocess.Popen(
             ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+            env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
         )
...
@@ -58,7 +62,8 @@ def server():
         "--dtype",
         "bfloat16",  # use half precision for speed and memory savings in CI environment
         "--max-model-len",
-        "8192"
+        "8192",
+        "--enforce-eager",
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
...
@@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == output


+async def test_batch_completions(server, client: openai.AsyncOpenAI):
+    # test simple list
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(batch.choices) == 2
+    assert batch.choices[0].text == batch.choices[1].text
+
+    # test n = 2
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        n=2,
+        max_tokens=5,
+        temperature=0.0,
+        extra_body=dict(
+            # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
+            use_beam_search=True),
+    )
+    assert len(batch.choices) == 4
+    assert batch.choices[0].text != batch.choices[1].text, "beam search should be different"
+    assert batch.choices[0].text == batch.choices[2].text, "two copies of the same prompt should be the same"
+    assert batch.choices[1].text == batch.choices[3].text, "two copies of the same prompt should be the same"
+
+    # test streaming
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+    )
+    texts = [""] * 2
+    async for chunk in batch:
+        assert len(chunk.choices) == 1
+        choice = chunk.choices[0]
+        texts[choice.index] += choice.text
+    assert texts[0] == texts[1]
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
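The assertions above also pin down how batched choices are ordered: with P prompts and n completions each, the response carries P * n choices grouped by prompt (prompt-major), which is what the server code below enforces via index=len(choices) in the non-streaming path and i = output.index + prompt_idx * request.n when streaming. A small illustrative snippet of that index layout (not part of the diff), assuming 2 prompts and n=2:

# Illustrative: the flattened choice indexing implied by the batched API,
# matching the test assertions and the `output.index + prompt_idx * request.n`
# mapping used for streaming.
n = 2            # completions per prompt (request.n)
num_prompts = 2  # len(request.prompt)

for prompt_idx in range(num_prompts):
    for output_idx in range(n):
        flat_index = output_idx + prompt_idx * n
        print(f"prompt {prompt_idx}, completion {output_idx} -> choices[{flat_index}]")
# prompt 0, completion 0 -> choices[0]
# prompt 0, completion 1 -> choices[1]
# prompt 1, completion 0 -> choices[2]
# prompt 1, completion 1 -> choices[3]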
vllm/entrypoints/openai/serving_completion.py (view file @ 3a7dd7e3)

+import asyncio
 import time
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator
+from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional

 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -18,48 +19,68 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
 logger = init_logger(__name__)

+TypeTokenIDs = list[int]
+TypeTopLogProbs = List[Optional[dict[int, float]]]
+TypeCreateLogProbsFn = Callable[
+    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
+

 async def completion_stream_generator(
-        request: CompletionRequest,
-        result_generator: AsyncIterator[RequestOutput],
-        echo_without_generation,
-        create_logprobs_fn,
-        request_id,
-        created_time,
-        model_name) -> AsyncGenerator[str, None]:
-    previous_texts = [""] * request.n
-    previous_num_tokens = [0] * request.n
-    has_echoed = [False] * request.n
-
-    async for res in result_generator:
-        # TODO: handle client disconnect for streaming
+    request: CompletionRequest,
+    raw_request: Request,
+    on_abort,
+    result_generator: AsyncIterator[tuple[int, RequestOutput]],
+    create_logprobs_fn: TypeCreateLogProbsFn,
+    request_id: str,
+    created_time: int,
+    model_name: str,
+    num_prompts: int,
+) -> AsyncGenerator[str, None]:
+    previous_texts = [""] * request.n * num_prompts
+    previous_num_tokens = [0] * request.n * num_prompts
+    has_echoed = [False] * request.n * num_prompts
+
+    async for prompt_idx, res in result_generator:
+
+        # Abort the request if the client disconnects.
+        if await raw_request.is_disconnected():
+            await on_abort(f"{request_id}-{prompt_idx}")
+            raise StopAsyncIteration()
+
         for output in res.outputs:
-            i = output.index
-            delta_text = output.text[len(previous_texts[i]):]
-            token_ids = output.token_ids[previous_num_tokens[i]:]
-            if request.logprobs is not None:
-                top_logprobs = output.logprobs[previous_num_tokens[i]:]
-            else:
-                top_logprobs = None
-            offsets = len(previous_texts[i])
-            if request.echo and not has_echoed[i]:
-                if not echo_without_generation:
-                    delta_text = res.prompt + delta_text
-                    token_ids = res.prompt_token_ids + token_ids
-                    if top_logprobs:
-                        top_logprobs = res.prompt_logprobs + top_logprobs
-                else:
-                    # only just return the prompt
-                    delta_text = res.prompt
-                    token_ids = res.prompt_token_ids
-                    if top_logprobs:
-                        top_logprobs = res.prompt_logprobs
-                has_echoed[i] = True
+            i = output.index + prompt_idx * request.n
+            # TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
+
+            if request.echo and request.max_tokens == 0:
+                # only return the prompt
+                delta_text = res.prompt
+                delta_token_ids = res.prompt_token_ids
+                top_logprobs = res.prompt_logprobs
+                has_echoed[i] = True
+            elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
+                # echo the prompt and first token
+                delta_text = res.prompt + output.text
+                delta_token_ids = res.prompt_token_ids + output.token_ids
+                top_logprobs = res.prompt_logprobs + (output.logprobs or [])
+                has_echoed[i] = True
+            else:
+                # return just the delta
+                delta_text = output.text[len(previous_texts[i]):]
+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None
+
             if request.logprobs is not None:
+                assert top_logprobs is not None, (
+                    "top_logprobs must be provided when logprobs is requested")
                 logprobs = create_logprobs_fn(
-                    token_ids=token_ids,
+                    token_ids=delta_token_ids,
                     top_logprobs=top_logprobs,
                     num_output_top_logprobs=request.logprobs,
-                    initial_text_offset=offsets,
+                    initial_text_offset=len(previous_texts[i]),
                 )
             else:
                 logprobs = None
+
             previous_texts[i] = output.text
             previous_num_tokens[i] = len(output.token_ids)
             finish_reason = output.finish_reason
...
@@ -77,7 +98,7 @@ async def completion_stream_generator(
                 ]).model_dump_json(exclude_unset=True)
                 yield f"data: {response_json}\n\n"

                 if output.finish_reason is not None:
                     # return final usage
                     logprobs = LogProbs() if request.logprobs is not None else None
                     prompt_tokens = len(res.prompt_token_ids)
                     completion_tokens = len(output.token_ids)
...
@@ -129,51 +150,58 @@ def parse_prompt_format(prompt) -> tuple[bool, list]:
     return prompt_is_tokens, prompts


-def request_output_to_completion_response(final_res: RequestOutput, request,
-                                          echo_without_generation,
-                                          create_logprobs_fn,
-                                          request_id,
-                                          created_time,
-                                          model_name) -> CompletionResponse:
-    assert final_res is not None
+def request_output_to_completion_response(
+    final_res_batch: list[RequestOutput],
+    request: CompletionRequest,
+    create_logprobs_fn: TypeCreateLogProbsFn,
+    request_id: str,
+    created_time: int,
+    model_name: str,
+) -> CompletionResponse:
     choices = []
-    prompt_token_ids = final_res.prompt_token_ids
-    prompt_logprobs = final_res.prompt_logprobs
-    prompt_text = final_res.prompt
-    for output in final_res.outputs:
-        if request.logprobs is not None:
-            if not echo_without_generation:
-                token_ids = output.token_ids
-                top_logprobs = output.logprobs
-                if request.echo:
-                    token_ids = prompt_token_ids + token_ids
-                    top_logprobs = prompt_logprobs + top_logprobs
-            else:
-                token_ids = prompt_token_ids
-                top_logprobs = prompt_logprobs
-            logprobs = create_logprobs_fn(
-                token_ids=token_ids,
-                top_logprobs=top_logprobs,
-                num_output_top_logprobs=request.logprobs,
-            )
-        else:
-            logprobs = None
-        if not echo_without_generation:
-            output_text = output.text
-            if request.echo:
-                output_text = prompt_text + output_text
-        else:
-            output_text = prompt_text
-        choice_data = CompletionResponseChoice(
-            index=output.index,
-            text=output_text,
-            logprobs=logprobs,
-            finish_reason=output.finish_reason,
-        )
-        choices.append(choice_data)
-
-    num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in final_res.outputs)
+    num_prompt_tokens = 0
+    num_generated_tokens = 0
+    for final_res in final_res_batch:
+        assert final_res is not None
+        prompt_token_ids = final_res.prompt_token_ids
+        prompt_logprobs = final_res.prompt_logprobs
+        prompt_text = final_res.prompt
+
+        for output in final_res.outputs:
+            if request.echo and request.max_tokens == 0:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+                output_text = prompt_text
+            elif request.echo and request.max_tokens > 0:
+                token_ids = prompt_token_ids + output.token_ids
+                top_logprobs = prompt_logprobs + output.logprobs
+                output_text = prompt_text + output.text
+            else:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                output_text = output.text
+
+            if request.logprobs is not None:
+                logprobs = create_logprobs_fn(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
+
+            choice_data = CompletionResponseChoice(
+                index=len(choices),
+                text=output_text,
+                logprobs=logprobs,
+                finish_reason=output.finish_reason,
+            )
+            choices.append(choice_data)
+
+        num_prompt_tokens += len(prompt_token_ids)
+        num_generated_tokens += sum(
+            len(output.token_ids) for output in final_res.outputs)
+
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
...
@@ -189,6 +217,36 @@ def request_output_to_completion_response(final_res: RequestOutput, request,
     )


+def merge_async_iterators(*iterators):
+    """Merge multiple asynchronous iterators into a single iterator.
+
+    This method handle the case where some iterators finish before others.
+    When it yields, it yields a tuple (i, item) where i is the index of the
+    iterator that yields the item.
+    """
+    queue = asyncio.Queue()
+
+    finished = [False] * len(iterators)
+
+    async def producer(i, iterator):
+        async for item in iterator:
+            await queue.put((i, item))
+        finished[i] = True
+
+    _tasks = [
+        asyncio.create_task(producer(i, iterator))
+        for i, iterator in enumerate(iterators)
+    ]
+
+    async def consumer():
+        while not all(finished) or not queue.empty():
+            item = await queue.get()
+            yield item
+        await asyncio.gather(*_tasks)
+
+    return consumer()
+
+
 class OpenAIServingCompletion(OpenAIServing):

     def __init__(self, engine: AsyncLLMEngine, served_model: str):
...
@@ -210,9 +268,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

-        # OpenAI API supports echoing the prompt when max_tokens is 0.
-        echo_without_generation = request.echo and request.max_tokens == 0
-
         # Return error for unsupported features.
         if request.suffix is not None:
             return self.create_error_response(
...
@@ -226,30 +281,30 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.monotonic())

         # Schedule the request and get the result generator.
+        generators = []
         try:
             sampling_params = request.to_sampling_params()
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

-            if len(prompts) > 1:
-                raise ValueError(
-                    "Batching in completion API is not supported.")
-            prompt = prompts[0]
-            if prompt_is_tokens:
-                input_ids = self._validate_prompt_and_tokenize(
-                    request, prompt_ids=prompt)
-            else:
-                input_ids = self._validate_prompt_and_tokenize(
-                    request, prompt=prompt)
-
-            result_generator = self.engine.generate(None,
-                                                    sampling_params,
-                                                    request_id,
-                                                    prompt_token_ids=input_ids)
+            for i, prompt in enumerate(prompts):
+                if prompt_is_tokens:
+                    input_ids = self._validate_prompt_and_tokenize(
+                        request, prompt_ids=prompt)
+                else:
+                    input_ids = self._validate_prompt_and_tokenize(
+                        request, prompt=prompt)
+
+                generators.append(
+                    self.engine.generate(None,
+                                         sampling_params,
+                                         f"{request_id}-{i}",
+                                         prompt_token_ids=input_ids))
         except ValueError as e:
             return self.create_error_response(str(e))

+        result_generator: AsyncIterator[tuple[int, RequestOutput]] = merge_async_iterators(*generators)
+
         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use beam search.
         stream = (request.stream
...
@@ -258,23 +313,27 @@ class OpenAIServingCompletion(OpenAIServing):
         # Streaming response
         if stream:
-            return completion_stream_generator(request, result_generator,
-                                               echo_without_generation,
-                                               self._create_logprobs,
-                                               request_id, created_time,
-                                               model_name)
+            return completion_stream_generator(request,
+                                               raw_request,
+                                               self.engine.abort,
+                                               result_generator,
+                                               self._create_logprobs,
+                                               request_id,
+                                               created_time,
+                                               model_name,
+                                               num_prompts=len(prompts))

         # Non-streaming response
-        final_res: RequestOutput = None
-        async for res in result_generator:
+        final_res_batch: RequestOutput = [None] * len(prompts)
+        async for i, res in result_generator:
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
+                await self.engine.abort(f"{request_id}-{i}")
                 return self.create_error_response("Client disconnected")
-            final_res = res
+            final_res_batch[i] = res
+
         response = request_output_to_completion_response(
-            final_res, request, echo_without_generation, self._create_logprobs,
+            final_res_batch, request, self._create_logprobs,
             request_id, created_time, model_name)

         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
...
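For context, the merging pattern introduced by the new merge_async_iterators helper can be exercised on its own. Below is a minimal standalone sketch (toy generators, not vLLM code) showing items arriving tagged with the index of the stream that produced them, which is what lets the server route per-prompt outputs back to the right slot:

# Standalone sketch of the (index, item) merging pattern that
# merge_async_iterators implements; the generators here are toy stand-ins
# for the per-prompt engine.generate(...) streams.
import asyncio


async def fake_stream(name: str, delay: float, steps: int):
    # Emits a few items with a per-stream delay to simulate uneven finish times.
    for step in range(steps):
        await asyncio.sleep(delay)
        yield f"{name}-{step}"


def merge_async_iterators(*iterators):
    # Same shape as the helper added in this commit: one producer task per
    # iterator pushes (index, item) pairs onto a shared queue, and a single
    # consumer generator drains the queue until every producer has finished.
    queue = asyncio.Queue()
    finished = [False] * len(iterators)

    async def producer(i, iterator):
        async for item in iterator:
            await queue.put((i, item))
        finished[i] = True

    tasks = [
        asyncio.create_task(producer(i, it)) for i, it in enumerate(iterators)
    ]

    async def consumer():
        while not all(finished) or not queue.empty():
            yield await queue.get()
        await asyncio.gather(*tasks)

    return consumer()


async def main():
    merged = merge_async_iterators(
        fake_stream("prompt0", 0.01, 3),
        fake_stream("prompt1", 0.03, 2),
    )
    async for idx, item in merged:
        print(idx, item)  # e.g. "0 prompt0-0", then "1 prompt1-0", ...


if __name__ == "__main__":
    asyncio.run(main())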