Commit 3a7dd7e3 in norm/vllm (unverified)
Authored by Simon Mo on Jan 24, 2024; committed via GitHub on Jan 24, 2024
Parent commit: 223c1922

Support Batch Completion in Server (#2529)
Showing 2 changed files with 215 additions and 105 deletions:

  tests/entrypoints/test_openai_server.py          +53   -2
  vllm/entrypoints/openai/serving_completion.py    +162  -103
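For orientation (not part of the commit itself): after this change, the prompt field of the OpenAI-compatible /v1/completions endpoint may be a list, and the server fans the batch out to the engine. A minimal sketch of the request shape, assuming a vLLM OpenAI-compatible server already running on http://localhost:8000 and a placeholder model name:

import requests

# Placeholder values; adjust to your deployment.
BASE_URL = "http://localhost:8000"
MODEL_NAME = "my-model"

resp = requests.post(
    f"{BASE_URL}/v1/completions",
    json={
        "model": MODEL_NAME,
        # A list of prompts instead of a single string: the batch case this
        # commit adds server-side support for.
        "prompt": ["Hello, my name is", "The capital of France is"],
        "max_tokens": 5,
        "temperature": 0.0,
    },
)
print(resp.json()["choices"])  # one choice per prompt (times n when n > 1)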
tests/entrypoints/test_openai_server.py (+53, -2)

-import time
+import os
 import subprocess
+import time
+import sys
 import pytest
...
@@ -17,8 +18,11 @@ pytestmark = pytest.mark.asyncio
 class ServerRunner:

     def __init__(self, args):
+        env = os.environ.copy()
+        env["PYTHONUNBUFFERED"] = "1"
         self.proc = subprocess.Popen(
             ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+            env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
         )
...
@@ -58,7 +62,8 @@ def server():
         "--dtype",
         "bfloat16",  # use half precision for speed and memory savings in CI environment
         "--max-model-len",
-        "8192"
+        "8192",
+        "--enforce-eager",
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
...
@@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == output


+async def test_batch_completions(server, client: openai.AsyncOpenAI):
+    # test simple list
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(batch.choices) == 2
+    assert batch.choices[0].text == batch.choices[1].text
+
+    # test n = 2
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        n=2,
+        max_tokens=5,
+        temperature=0.0,
+        extra_body=dict(
+            # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
+            use_beam_search=True),
+    )
+    assert len(batch.choices) == 4
+    assert batch.choices[0].text != batch.choices[
+        1].text, "beam search should be different"
+    assert batch.choices[0].text == batch.choices[
+        2].text, "two copies of the same prompt should be the same"
+    assert batch.choices[1].text == batch.choices[
+        3].text, "two copies of the same prompt should be the same"
+
+    # test streaming
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+    )
+    texts = [""] * 2
+    async for chunk in batch:
+        assert len(chunk.choices) == 1
+        choice = chunk.choices[0]
+        texts[choice.index] += choice.text
+    assert texts[0] == texts[1]
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
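A note on the assertions above (an illustration, not code from the commit): choices in a batched response are appended prompt by prompt, so with two prompts and n=2 beam search, indices 0-1 belong to the first prompt and 2-3 to the second. That is why choices[0] and choices[2] (the top beam of two identical prompts) are expected to match, while choices[0] and choices[1] (two different beams) differ. A small helper with a hypothetical name makes the mapping explicit:

def choice_slice(prompt_idx: int, n: int) -> slice:
    # Hypothetical helper: the range of batch.choices belonging to one prompt,
    # assuming the prompt-major ordering used in this commit.
    return slice(prompt_idx * n, (prompt_idx + 1) * n)

assert choice_slice(0, 2) == slice(0, 2)  # first prompt  -> choices[0:2]
assert choice_slice(1, 2) == slice(2, 4)  # second prompt -> choices[2:4]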
vllm/entrypoints/openai/serving_completion.py (+162, -103)

+import asyncio
 import time
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator
+from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional

 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -18,48 +19,68 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing

 logger = init_logger(__name__)

+TypeTokenIDs = list[int]
+TypeTopLogProbs = List[Optional[dict[int, float]]]
+TypeCreateLogProbsFn = Callable[
+    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
+

 async def completion_stream_generator(
         request: CompletionRequest,
-        result_generator: AsyncIterator[RequestOutput],
-        echo_without_generation, create_logprobs_fn, request_id, created_time,
-        model_name) -> AsyncGenerator[str, None]:
-    previous_texts = [""] * request.n
-    previous_num_tokens = [0] * request.n
-    has_echoed = [False] * request.n
-
-    async for res in result_generator:
-        # TODO: handle client disconnect for streaming
+        raw_request: Request,
+        on_abort,
+        result_generator: AsyncIterator[tuple[int, RequestOutput]],
+        create_logprobs_fn: TypeCreateLogProbsFn,
+        request_id: str,
+        created_time: int,
+        model_name: str,
+        num_prompts: int,
+) -> AsyncGenerator[str, None]:
+    previous_texts = [""] * request.n * num_prompts
+    previous_num_tokens = [0] * request.n * num_prompts
+    has_echoed = [False] * request.n * num_prompts
+
+    async for prompt_idx, res in result_generator:
+
+        # Abort the request if the client disconnects.
+        if await raw_request.is_disconnected():
+            await on_abort(f"{request_id}-{prompt_idx}")
+            raise StopAsyncIteration()
+
         for output in res.outputs:
-            i = output.index
-            delta_text = output.text[len(previous_texts[i]):]
-            token_ids = output.token_ids[previous_num_tokens[i]:]
-            if request.logprobs is not None:
-                top_logprobs = output.logprobs[previous_num_tokens[i]:]
-            else:
-                top_logprobs = None
-            offsets = len(previous_texts[i])
-            if request.echo and not has_echoed[i]:
-                if not echo_without_generation:
-                    delta_text = res.prompt + delta_text
-                    token_ids = res.prompt_token_ids + token_ids
-                    if top_logprobs:
-                        top_logprobs = res.prompt_logprobs + top_logprobs
-                else:
-                    # only just return the prompt
-                    delta_text = res.prompt
-                    token_ids = res.prompt_token_ids
-                    if top_logprobs:
-                        top_logprobs = res.prompt_logprobs
-                has_echoed[i] = True
-            if request.logprobs is not None:
-                logprobs = create_logprobs_fn(
-                    token_ids=token_ids,
-                    top_logprobs=top_logprobs,
-                    num_output_top_logprobs=request.logprobs,
-                    initial_text_offset=offsets,
-                )
-            else:
-                logprobs = None
+            i = output.index + prompt_idx * request.n
+            # TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
+
+            if request.echo and request.max_tokens == 0:
+                # only return the prompt
+                delta_text = res.prompt
+                delta_token_ids = res.prompt_token_ids
+                top_logprobs = res.prompt_logprobs
+                has_echoed[i] = True
+            elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
+                # echo the prompt and first token
+                delta_text = res.prompt + output.text
+                delta_token_ids = res.prompt_token_ids + output.token_ids
+                top_logprobs = res.prompt_logprobs + (output.logprobs or [])
+                has_echoed[i] = True
+            else:
+                # return just the delta
+                delta_text = output.text[len(previous_texts[i]):]
+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None
+
+            if request.logprobs is not None:
+                assert top_logprobs is not None, (
+                    "top_logprobs must be provided when logprobs is requested")
+                logprobs = create_logprobs_fn(
+                    token_ids=delta_token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                    initial_text_offset=len(previous_texts[i]),
+                )
+            else:
+                logprobs = None
+
             previous_texts[i] = output.text
             previous_num_tokens[i] = len(output.token_ids)
             finish_reason = output.finish_reason
...
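To make the streaming bookkeeping concrete (a toy illustration, not part of the commit): each (prompt, candidate) pair gets its own slot in previous_texts, previous_num_tokens, and has_echoed, with the flat index computed as output.index + prompt_idx * request.n:

# Toy values: 2 prompts, n = 2 candidates per prompt -> 4 slots in total.
n = 2
num_prompts = 2
slots = {}
for prompt_idx in range(num_prompts):
    for output_index in range(n):
        i = output_index + prompt_idx * n  # same formula as in the diff above
        slots[(prompt_idx, output_index)] = i
assert slots == {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}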
@@ -77,7 +98,7 @@ async def completion_stream_generator(
             ]).model_dump_json(exclude_unset=True)
             yield f"data: {response_json}\n\n"

-        if output.finish_reason is not None:
+            if output.finish_reason is not None:
                 # return final usage
                 logprobs = LogProbs() if request.logprobs is not None else None
                 prompt_tokens = len(res.prompt_token_ids)
                 completion_tokens = len(output.token_ids)
...
@@ -129,27 +150,38 @@ def parse_prompt_format(prompt) -> tuple[bool, list]:
     return prompt_is_tokens, prompts


-def request_output_to_completion_response(final_res: RequestOutput, request,
-                                           echo_without_generation,
-                                           create_logprobs_fn, request_id,
-                                           created_time,
-                                           model_name) -> CompletionResponse:
-    assert final_res is not None
+def request_output_to_completion_response(
+    final_res_batch: list[RequestOutput],
+    request: CompletionRequest,
+    create_logprobs_fn: TypeCreateLogProbsFn,
+    request_id: str,
+    created_time: int,
+    model_name: str,
+) -> CompletionResponse:
     choices = []
-    prompt_token_ids = final_res.prompt_token_ids
-    prompt_logprobs = final_res.prompt_logprobs
-    prompt_text = final_res.prompt
-    for output in final_res.outputs:
-        if request.logprobs is not None:
-            if not echo_without_generation:
-                token_ids = output.token_ids
-                top_logprobs = output.logprobs
-                if request.echo:
-                    token_ids = prompt_token_ids + token_ids
-                    top_logprobs = prompt_logprobs + top_logprobs
-            else:
-                token_ids = prompt_token_ids
-                top_logprobs = prompt_logprobs
+    num_prompt_tokens = 0
+    num_generated_tokens = 0
+    for final_res in final_res_batch:
+        assert final_res is not None
+        prompt_token_ids = final_res.prompt_token_ids
+        prompt_logprobs = final_res.prompt_logprobs
+        prompt_text = final_res.prompt
+
+        for output in final_res.outputs:
+            if request.echo and request.max_tokens == 0:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+                output_text = prompt_text
+            elif request.echo and request.max_tokens > 0:
+                token_ids = prompt_token_ids + output.token_ids
+                top_logprobs = prompt_logprobs + output.logprobs
+                output_text = prompt_text + output.text
+            else:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                output_text = output.text
+
+            if request.logprobs is not None:
                 logprobs = create_logprobs_fn(
                     token_ids=token_ids,
                     top_logprobs=top_logprobs,
...
@@ -157,23 +189,19 @@ def request_output_to_completion_response(final_res: RequestOutput, request,
                 )
             else:
                 logprobs = None
-        if not echo_without_generation:
-            output_text = output.text
-            if request.echo:
-                output_text = prompt_text + output_text
-        else:
-            output_text = prompt_text
-        choice_data = CompletionResponseChoice(
-            index=output.index,
+
+            choice_data = CompletionResponseChoice(
+                index=len(choices),
                 text=output_text,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
-    num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in final_res.outputs)
+
+        num_prompt_tokens += len(prompt_token_ids)
+        num_generated_tokens += sum(
+            len(output.token_ids) for output in final_res.outputs)

     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
...
@@ -189,6 +217,36 @@ def request_output_to_completion_response(final_res: RequestOutput, request,
     )


+def merge_async_iterators(*iterators):
+    """Merge multiple asynchronous iterators into a single iterator.
+
+    This method handle the case where some iterators finish before others.
+    When it yields, it yields a tuple (i, item) where i is the index of the
+    iterator that yields the item.
+    """
+    queue = asyncio.Queue()
+
+    finished = [False] * len(iterators)
+
+    async def producer(i, iterator):
+        async for item in iterator:
+            await queue.put((i, item))
+        finished[i] = True
+
+    _tasks = [
+        asyncio.create_task(producer(i, iterator))
+        for i, iterator in enumerate(iterators)
+    ]
+
+    async def consumer():
+        while not all(finished) or not queue.empty():
+            item = await queue.get()
+            yield item
+        await asyncio.gather(*_tasks)
+
+    return consumer()
+
+
 class OpenAIServingCompletion(OpenAIServing):

     def __init__(self, engine: AsyncLLMEngine, served_model: str):
...
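A small usage sketch for the merge_async_iterators helper above (toy iterators, not part of the commit); it shows how the consumer receives (index, item) tuples as each source iterator produces values. The import path assumes the function stays module-level in vllm.entrypoints.openai.serving_completion:

import asyncio

from vllm.entrypoints.openai.serving_completion import merge_async_iterators

async def ticker(name: str, delay: float, count: int):
    # Toy async iterator standing in for one engine result generator.
    for j in range(count):
        await asyncio.sleep(delay)
        yield f"{name}-{j}"

async def demo():
    merged = merge_async_iterators(ticker("a", 0.01, 2), ticker("b", 0.02, 2))
    async for i, item in merged:
        print(i, item)  # i is the index of the iterator that produced item

asyncio.run(demo())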
@@ -210,9 +268,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

-        # OpenAI API supports echoing the prompt when max_tokens is 0.
-        echo_without_generation = request.echo and request.max_tokens == 0
-
         # Return error for unsupported features.
         if request.suffix is not None:
             return self.create_error_response(
...
@@ -226,30 +281,30 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.monotonic())

         # Schedule the request and get the result generator.
+        generators = []
         try:
             sampling_params = request.to_sampling_params()
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

-            if len(prompts) > 1:
-                raise ValueError(
-                    "Batching in completion API is not supported.")
-            prompt = prompts[0]
-
-            if prompt_is_tokens:
-                input_ids = self._validate_prompt_and_tokenize(
-                    request, prompt_ids=prompt)
-            else:
-                input_ids = self._validate_prompt_and_tokenize(
-                    request, prompt=prompt)
-
-            result_generator = self.engine.generate(None,
-                                                    sampling_params,
-                                                    request_id,
-                                                    prompt_token_ids=input_ids)
+            for i, prompt in enumerate(prompts):
+                if prompt_is_tokens:
+                    input_ids = self._validate_prompt_and_tokenize(
+                        request, prompt_ids=prompt)
+                else:
+                    input_ids = self._validate_prompt_and_tokenize(
+                        request, prompt=prompt)
+
+                generators.append(
+                    self.engine.generate(None,
+                                         sampling_params,
+                                         f"{request_id}-{i}",
+                                         prompt_token_ids=input_ids))
         except ValueError as e:
             return self.create_error_response(str(e))

+        result_generator: AsyncIterator[tuple[
+            int, RequestOutput]] = merge_async_iterators(*generators)
+
         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use beam search.
         stream = (request.stream
...
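The stream condition is truncated in this view, but the comment above it states the rule: no streaming when n != best_of (candidates must be re-ranked server-side) or when beam search is used. A standalone sketch of a predicate consistent with that comment; the field names best_of and use_beam_search are assumptions about CompletionRequest, and this is not necessarily the exact expression in the file:

from typing import Optional

def should_stream(stream_requested: bool, n: int, best_of: Optional[int],
                  use_beam_search: bool) -> bool:
    # Stream only when the client asked for it, when n equals best_of (so no
    # server-side re-ranking is needed), and when beam search is not in use.
    effective_best_of = best_of if best_of is not None else n
    return stream_requested and n == effective_best_of and not use_beam_search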
@@ -258,23 +313,27 @@ class OpenAIServingCompletion(OpenAIServing):
         # Streaming response
         if stream:
-            return completion_stream_generator(request, result_generator,
-                                               echo_without_generation,
-                                               self._create_logprobs,
-                                               request_id, created_time,
-                                               model_name)
+            return completion_stream_generator(request,
+                                               raw_request,
+                                               self.engine.abort,
+                                               result_generator,
+                                               self._create_logprobs,
+                                               request_id,
+                                               created_time,
+                                               model_name,
+                                               num_prompts=len(prompts))

         # Non-streaming response
-        final_res: RequestOutput = None
-        async for res in result_generator:
+        final_res_batch: RequestOutput = [None] * len(prompts)
+        async for i, res in result_generator:
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
+                await self.engine.abort(f"{request_id}-{i}")
                 return self.create_error_response("Client disconnected")
-            final_res = res
-        response = request_output_to_completion_response(
-            final_res, request, echo_without_generation, self._create_logprobs,
-            request_id, created_time, model_name)
+            final_res_batch[i] = res
+
+        response = request_output_to_completion_response(
+            final_res_batch, request, self._create_logprobs, request_id,
+            created_time, model_name)

         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
...