Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2d1b9baa
Unverified
Commit
2d1b9baa
authored
Dec 17, 2024
by
Joe Runde
Committed by
GitHub
Dec 17, 2024
Browse files
[Bugfix] Fix request cancellation without polling (#11190)
parent
f9ecbb18
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
164 additions
and
103 deletions
+164
-103
tests/entrypoints/openai/test_basic.py
tests/entrypoints/openai/test_basic.py
+51
-0
tests/test_utils.py
tests/test_utils.py
+1
-5
tests/utils.py
tests/utils.py
+5
-6
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+27
-19
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+7
-4
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+8
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+0
-5
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-2
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+1
-4
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+1
-4
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+57
-0
vllm/utils.py
vllm/utils.py
+5
-54
No files found.
tests/entrypoints/openai/test_basic.py
View file @
2d1b9baa
import
asyncio
from
http
import
HTTPStatus
from
typing
import
List
import
openai
import
pytest
import
pytest_asyncio
import
requests
...
...
@@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer):
response
=
requests
.
get
(
server
.
url_for
(
"health"
))
assert
response
.
status_code
==
HTTPStatus
.
OK
@
pytest
.
mark
.
parametrize
(
"server_args"
,
[
pytest
.
param
([
"--max-model-len"
,
"10100"
],
id
=
"default-frontend-multiprocessing"
),
pytest
.
param
(
[
"--disable-frontend-multiprocessing"
,
"--max-model-len"
,
"10100"
],
id
=
"disable-frontend-multiprocessing"
)
],
indirect
=
True
,
)
@
pytest
.
mark
.
asyncio
async
def
test_request_cancellation
(
server
:
RemoteOpenAIServer
):
# clunky test: send an ungodly amount of load in with short timeouts
# then ensure that it still responds quickly afterwards
chat_input
=
[{
"role"
:
"user"
,
"content"
:
"Write a long story"
}]
client
=
server
.
get_async_client
(
timeout
=
0.5
)
tasks
=
[]
# Request about 2 million tokens
for
_
in
range
(
200
):
task
=
asyncio
.
create_task
(
client
.
chat
.
completions
.
create
(
messages
=
chat_input
,
model
=
MODEL_NAME
,
max_tokens
=
10000
,
extra_body
=
{
"min_tokens"
:
10000
}))
tasks
.
append
(
task
)
done
,
pending
=
await
asyncio
.
wait
(
tasks
,
return_when
=
asyncio
.
ALL_COMPLETED
)
# Make sure all requests were sent to the server and timed out
# (We don't want to hide other errors like 400s that would invalidate this
# test)
assert
len
(
pending
)
==
0
for
d
in
done
:
with
pytest
.
raises
(
openai
.
APITimeoutError
):
d
.
result
()
# If the server had not cancelled all the other requests, then it would not
# be able to respond to this one within the timeout
client
=
server
.
get_async_client
(
timeout
=
5
)
response
=
await
client
.
chat
.
completions
.
create
(
messages
=
chat_input
,
model
=
MODEL_NAME
,
max_tokens
=
10
)
assert
len
(
response
.
choices
)
==
1
tests/test_utils.py
View file @
2d1b9baa
import
asyncio
import
os
import
socket
from
functools
import
partial
from
typing
import
AsyncIterator
,
Tuple
import
pytest
...
...
@@ -26,10 +25,7 @@ async def test_merge_async_iterators():
print
(
f
"iterator
{
idx
}
cancelled"
)
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
merged_iterator
=
merge_async_iterators
(
*
iterators
,
is_cancelled
=
partial
(
asyncio
.
sleep
,
0
,
result
=
False
))
merged_iterator
=
merge_async_iterators
(
*
iterators
)
async
def
stream_output
(
generator
:
AsyncIterator
[
Tuple
[
int
,
str
]]):
async
for
idx
,
output
in
generator
:
...
...
tests/utils.py
View file @
2d1b9baa
...
...
@@ -163,12 +163,11 @@ class RemoteOpenAIServer:
api_key
=
self
.
DUMMY_API_KEY
,
)
def
get_async_client
(
self
):
return
openai
.
AsyncOpenAI
(
base_url
=
self
.
url_for
(
"v1"
),
def
get_async_client
(
self
,
**
kwargs
):
return
openai
.
AsyncOpenAI
(
base_url
=
self
.
url_for
(
"v1"
),
api_key
=
self
.
DUMMY_API_KEY
,
max_retries
=
0
,
)
**
kwargs
)
def
_test_completion
(
...
...
vllm/engine/async_llm_engine.py
View file @
2d1b9baa
...
...
@@ -1065,6 +1065,7 @@ class AsyncLLMEngine(EngineClient):
>>> # Process and return the final output
>>> ...
"""
try
:
async
for
output
in
await
self
.
add_request
(
request_id
,
prompt
,
...
...
@@ -1075,6 +1076,9 @@ class AsyncLLMEngine(EngineClient):
priority
=
priority
,
):
yield
LLMEngine
.
validate_output
(
output
,
RequestOutput
)
except
asyncio
.
CancelledError
:
await
self
.
abort
(
request_id
)
raise
async
def
encode
(
self
,
...
...
@@ -1147,6 +1151,7 @@ class AsyncLLMEngine(EngineClient):
>>> # Process and return the final output
>>> ...
"""
try
:
async
for
output
in
await
self
.
add_request
(
request_id
,
prompt
,
...
...
@@ -1156,6 +1161,9 @@ class AsyncLLMEngine(EngineClient):
priority
=
priority
,
):
yield
LLMEngine
.
validate_output
(
output
,
PoolingRequestOutput
)
except
asyncio
.
CancelledError
:
await
self
.
abort
(
request_id
)
raise
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
...
...
vllm/entrypoints/api_server.py
View file @
2d1b9baa
...
...
@@ -17,11 +17,11 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.utils
import
with_cancellation
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
iterate_with_cancellation
,
random_uuid
)
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
...
...
@@ -47,6 +47,11 @@ async def generate(request: Request) -> Response:
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict
=
await
request
.
json
()
return
await
_generate
(
request_dict
,
raw_request
=
request
)
@
with_cancellation
async
def
_generate
(
request_dict
:
dict
,
raw_request
:
Request
)
->
Response
:
prompt
=
request_dict
.
pop
(
"prompt"
)
stream
=
request_dict
.
pop
(
"stream"
,
False
)
sampling_params
=
SamplingParams
(
**
request_dict
)
...
...
@@ -54,8 +59,6 @@ async def generate(request: Request) -> Response:
assert
engine
is
not
None
results_generator
=
engine
.
generate
(
prompt
,
sampling_params
,
request_id
)
results_generator
=
iterate_with_cancellation
(
results_generator
,
is_cancelled
=
request
.
is_disconnected
)
# Streaming case
async
def
stream_results
()
->
AsyncGenerator
[
bytes
,
None
]:
...
...
vllm/entrypoints/openai/api_server.py
View file @
2d1b9baa
...
...
@@ -59,6 +59,7 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.utils
import
with_cancellation
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
...
...
@@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response:
@
router
.
post
(
"/tokenize"
)
@
with_cancellation
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
...
...
@@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
@
router
.
post
(
"/detokenize"
)
@
with_cancellation
async
def
detokenize
(
request
:
DetokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
...
...
@@ -353,6 +356,7 @@ async def show_version():
@
router
.
post
(
"/v1/chat/completions"
)
@
with_cancellation
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
handler
=
chat
(
raw_request
)
...
...
@@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@
router
.
post
(
"/v1/completions"
)
@
with_cancellation
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
handler
=
completion
(
raw_request
)
if
handler
is
None
:
...
...
@@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@
router
.
post
(
"/v1/embeddings"
)
@
with_cancellation
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
handler
=
embedding
(
raw_request
)
if
handler
is
None
:
...
...
@@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@
router
.
post
(
"/score"
)
@
with_cancellation
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
handler
=
score
(
raw_request
)
if
handler
is
None
:
...
...
@@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
@
router
.
post
(
"/v1/score"
)
@
with_cancellation
async
def
create_score_v1
(
request
:
ScoreRequest
,
raw_request
:
Request
):
logger
.
warning
(
"To indicate that Score API is not part of standard OpenAI API, we "
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
2d1b9baa
...
...
@@ -32,7 +32,6 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
maybe_serialize_tool_calls
from
vllm.utils
import
iterate_with_cancellation
logger
=
init_logger
(
__name__
)
...
...
@@ -234,10 +233,6 @@ class OpenAIServingChat(OpenAIServing):
assert
len
(
generators
)
==
1
result_generator
,
=
generators
if
raw_request
:
result_generator
=
iterate_with_cancellation
(
result_generator
,
raw_request
.
is_disconnected
)
# Streaming response
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
2d1b9baa
...
...
@@ -159,8 +159,7 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
=
merge_async_iterators
(
*
generators
,
is_cancelled
=
raw_request
.
is_disconnected
)
result_generator
=
merge_async_iterators
(
*
generators
)
model_name
=
self
.
_get_model_name
(
lora_request
)
num_prompts
=
len
(
engine_prompts
)
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
2d1b9baa
...
...
@@ -202,10 +202,7 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
=
merge_async_iterators
(
*
generators
,
is_cancelled
=
raw_request
.
is_disconnected
if
raw_request
else
None
,
)
result_generator
=
merge_async_iterators
(
*
generators
)
num_prompts
=
len
(
engine_prompts
)
...
...
vllm/entrypoints/openai/serving_score.py
View file @
2d1b9baa
...
...
@@ -186,10 +186,7 @@ class OpenAIServingScores(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
=
merge_async_iterators
(
*
generators
,
is_cancelled
=
raw_request
.
is_disconnected
if
raw_request
else
None
,
)
result_generator
=
merge_async_iterators
(
*
generators
)
num_prompts
=
len
(
engine_prompts
)
...
...
vllm/entrypoints/utils.py
0 → 100644
View file @
2d1b9baa
import
asyncio
import
functools
from
fastapi
import
Request
async
def
listen_for_disconnect
(
request
:
Request
)
->
None
:
"""Returns if a disconnect message is received"""
while
True
:
message
=
await
request
.
receive
()
if
message
[
"type"
]
==
"http.disconnect"
:
break
def
with_cancellation
(
handler_func
):
"""Decorator that allows a route handler to be cancelled by client
disconnections.
This does _not_ use request.is_disconnected, which does not work with
middleware. Instead this follows the pattern from
starlette.StreamingResponse, which simultaneously awaits on two tasks- one
to wait for an http disconnect message, and the other to do the work that we
want done. When the first task finishes, the other is cancelled.
A core assumption of this method is that the body of the request has already
been read. This is a safe assumption to make for fastapi handlers that have
already parsed the body of the request into a pydantic model for us.
This decorator is unsafe to use elsewhere, as it will consume and throw away
all incoming messages for the request while it looks for a disconnect
message.
In the case where a `StreamingResponse` is returned by the handler, this
wrapper will stop listening for disconnects and instead the response object
will start listening for disconnects.
"""
# Functools.wraps is required for this wrapper to appear to fastapi as a
# normal route handler, with the correct request type hinting.
@
functools
.
wraps
(
handler_func
)
async
def
wrapper
(
*
args
,
**
kwargs
):
# The request is either the second positional arg or `raw_request`
request
=
args
[
1
]
if
len
(
args
)
>
1
else
kwargs
[
"raw_request"
]
handler_task
=
asyncio
.
create_task
(
handler_func
(
*
args
,
**
kwargs
))
cancellation_task
=
asyncio
.
create_task
(
listen_for_disconnect
(
request
))
done
,
pending
=
await
asyncio
.
wait
([
handler_task
,
cancellation_task
],
return_when
=
asyncio
.
FIRST_COMPLETED
)
for
task
in
pending
:
task
.
cancel
()
if
handler_task
in
done
:
return
handler_task
.
result
()
return
None
return
wrapper
vllm/utils.py
View file @
2d1b9baa
...
...
@@ -20,7 +20,7 @@ import time
import
uuid
import
warnings
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Future
,
Task
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
dataclasses
import
dataclass
,
field
...
...
@@ -370,72 +370,23 @@ def _next_task(iterator: AsyncGenerator[T, None],
return
loop
.
create_task
(
iterator
.
__anext__
())
# type: ignore[arg-type]
async
def
iterate_with_cancellation
(
iterator
:
AsyncGenerator
[
T
,
None
],
is_cancelled
:
Callable
[[],
Awaitable
[
bool
]],
)
->
AsyncGenerator
[
T
,
None
]:
"""Convert async iterator into one that polls the provided function
at least once per second to check for client cancellation.
"""
loop
=
asyncio
.
get_running_loop
()
awaits
:
List
[
Future
[
T
]]
=
[
_next_task
(
iterator
,
loop
)]
next_cancel_check
:
float
=
0
while
True
:
done
,
pending
=
await
asyncio
.
wait
(
awaits
,
timeout
=
1.5
)
# Check for cancellation at most once per second
time_now
=
time
.
time
()
if
time_now
>=
next_cancel_check
:
if
await
is_cancelled
():
with
contextlib
.
suppress
(
BaseException
):
awaits
[
0
].
cancel
()
await
iterator
.
aclose
()
raise
asyncio
.
CancelledError
(
"client cancelled"
)
next_cancel_check
=
time_now
+
1
if
done
:
try
:
item
=
await
awaits
[
0
]
awaits
[
0
]
=
_next_task
(
iterator
,
loop
)
yield
item
except
StopAsyncIteration
:
# we are done
return
async
def
merge_async_iterators
(
*
iterators
:
AsyncGenerator
[
T
,
None
],
is_cancelled
:
Optional
[
Callable
[[],
Awaitable
[
bool
]]]
=
None
,
)
->
AsyncGenerator
[
Tuple
[
int
,
T
],
None
]:
*
iterators
:
AsyncGenerator
[
T
,
None
],
)
->
AsyncGenerator
[
Tuple
[
int
,
T
],
None
]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
It also optionally polls a provided function at least once per second
to check for client cancellation.
"""
loop
=
asyncio
.
get_running_loop
()
awaits
=
{
_next_task
(
pair
[
1
],
loop
):
pair
for
pair
in
enumerate
(
iterators
)}
timeout
=
None
if
is_cancelled
is
None
else
1.5
next_cancel_check
:
float
=
0
try
:
while
awaits
:
done
,
pending
=
await
asyncio
.
wait
(
awaits
.
keys
(),
return_when
=
FIRST_COMPLETED
,
timeout
=
timeout
)
if
is_cancelled
is
not
None
:
# Check for cancellation at most once per second
time_now
=
time
.
time
()
if
time_now
>=
next_cancel_check
:
if
await
is_cancelled
():
raise
asyncio
.
CancelledError
(
"client cancelled"
)
next_cancel_check
=
time_now
+
1
done
,
_
=
await
asyncio
.
wait
(
awaits
.
keys
(),
return_when
=
FIRST_COMPLETED
)
for
d
in
done
:
pair
=
awaits
.
pop
(
d
)
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment