Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a3f49ae
Unverified
Commit
9a3f49ae
authored
Aug 06, 2024
by
Nick Hill
Committed by
GitHub
Aug 07, 2024
Browse files
[BugFix] Overhaul async request cancellation (#7111)
parent
f9a56006
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
226 additions
and
226 deletions
+226
-226
tests/async_engine/api_server_async_engine.py
tests/async_engine/api_server_async_engine.py
+5
-4
tests/async_engine/test_request_tracker.py
tests/async_engine/test_request_tracker.py
+12
-13
tests/test_utils.py
tests/test_utils.py
+3
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+67
-79
vllm/engine/protocol.py
vllm/engine/protocol.py
+5
-5
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+9
-7
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+32
-30
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+18
-16
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+6
-14
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+7
-7
vllm/utils.py
vllm/utils.py
+62
-49
No files found.
tests/async_engine/api_server_async_engine.py
View file @
9a3f49ae
"""vllm.entrypoints.api_server with some extra logging for testing."""
"""vllm.entrypoints.api_server with some extra logging for testing."""
from
typing
import
Any
,
Dict
from
typing
import
Any
,
Dict
,
Iterable
import
uvicorn
import
uvicorn
from
fastapi.responses
import
JSONResponse
,
Response
from
fastapi.responses
import
JSONResponse
,
Response
...
@@ -18,9 +18,10 @@ class AsyncLLMEngineWithStats(AsyncLLMEngine):
...
@@ -18,9 +18,10 @@ class AsyncLLMEngineWithStats(AsyncLLMEngine):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_num_aborts
=
0
self
.
_num_aborts
=
0
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
await
super
().
abort
(
request_id
)
ids
=
list
(
request_ids
)
self
.
_num_aborts
+=
1
self
.
_num_aborts
+=
len
(
ids
)
await
super
().
_engine_abort
(
ids
)
def
testing_stats
(
self
)
->
Dict
[
str
,
Any
]:
def
testing_stats
(
self
)
->
Dict
[
str
,
Any
]:
return
{
"num_aborted_requests"
:
self
.
_num_aborts
}
return
{
"num_aborted_requests"
:
self
.
_num_aborts
}
...
...
tests/async_engine/test_request_tracker.py
View file @
9a3f49ae
...
@@ -10,23 +10,23 @@ async def test_request_tracker():
...
@@ -10,23 +10,23 @@ async def test_request_tracker():
stream_1
=
tracker
.
add_request
(
"1"
)
stream_1
=
tracker
.
add_request
(
"1"
)
assert
tracker
.
new_requests_event
.
is_set
()
assert
tracker
.
new_requests_event
.
is_set
()
await
tracker
.
wait_for_new_requests
()
await
tracker
.
wait_for_new_requests
()
new
,
finish
ed
=
tracker
.
get_new_and_
finish
ed_requests
()
new
,
abort
ed
=
tracker
.
get_new_and_
abort
ed_requests
()
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
len
(
new
)
==
1
assert
len
(
new
)
==
1
assert
new
[
0
][
"request_id"
]
==
"1"
assert
new
[
0
][
"request_id"
]
==
"1"
assert
not
finish
ed
assert
not
abort
ed
assert
not
stream_1
.
finished
assert
not
stream_1
.
finished
stream_2
=
tracker
.
add_request
(
"2"
)
stream_2
=
tracker
.
add_request
(
"2"
)
stream_3
=
tracker
.
add_request
(
"3"
)
stream_3
=
tracker
.
add_request
(
"3"
)
assert
tracker
.
new_requests_event
.
is_set
()
assert
tracker
.
new_requests_event
.
is_set
()
await
tracker
.
wait_for_new_requests
()
await
tracker
.
wait_for_new_requests
()
new
,
finish
ed
=
tracker
.
get_new_and_
finish
ed_requests
()
new
,
abort
ed
=
tracker
.
get_new_and_
abort
ed_requests
()
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
len
(
new
)
==
2
assert
len
(
new
)
==
2
assert
new
[
0
][
"request_id"
]
==
"2"
assert
new
[
0
][
"request_id"
]
==
"2"
assert
new
[
1
][
"request_id"
]
==
"3"
assert
new
[
1
][
"request_id"
]
==
"3"
assert
not
finish
ed
assert
not
abort
ed
assert
not
stream_2
.
finished
assert
not
stream_2
.
finished
assert
not
stream_3
.
finished
assert
not
stream_3
.
finished
...
@@ -36,9 +36,9 @@ async def test_request_tracker():
...
@@ -36,9 +36,9 @@ async def test_request_tracker():
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
not
tracker
.
new_requests_event
.
is_set
()
tracker
.
abort_request
(
"1"
)
tracker
.
abort_request
(
"1"
)
new
,
finish
ed
=
tracker
.
get_new_and_
finish
ed_requests
()
new
,
abort
ed
=
tracker
.
get_new_and_
abort
ed_requests
()
assert
len
(
finish
ed
)
==
1
assert
len
(
abort
ed
)
==
1
assert
"1"
in
finish
ed
assert
"1"
in
abort
ed
assert
not
new
assert
not
new
assert
stream_1
.
finished
assert
stream_1
.
finished
...
@@ -46,9 +46,9 @@ async def test_request_tracker():
...
@@ -46,9 +46,9 @@ async def test_request_tracker():
tracker
.
abort_request
(
"4"
)
tracker
.
abort_request
(
"4"
)
assert
tracker
.
new_requests_event
.
is_set
()
assert
tracker
.
new_requests_event
.
is_set
()
await
tracker
.
wait_for_new_requests
()
await
tracker
.
wait_for_new_requests
()
new
,
finish
ed
=
tracker
.
get_new_and_
finish
ed_requests
()
new
,
abort
ed
=
tracker
.
get_new_and_
abort
ed_requests
()
assert
len
(
finish
ed
)
==
1
assert
len
(
abort
ed
)
==
1
assert
"4"
in
finish
ed
assert
"4"
in
abort
ed
assert
not
new
assert
not
new
assert
stream_4
.
finished
assert
stream_4
.
finished
...
@@ -57,10 +57,9 @@ async def test_request_tracker():
...
@@ -57,10 +57,9 @@ async def test_request_tracker():
tracker
.
process_request_output
(
tracker
.
process_request_output
(
RequestOutput
(
"2"
,
"output"
,
[],
[],
[],
finished
=
True
))
RequestOutput
(
"2"
,
"output"
,
[],
[],
[],
finished
=
True
))
await
tracker
.
wait_for_new_requests
()
await
tracker
.
wait_for_new_requests
()
new
,
finish
ed
=
tracker
.
get_new_and_
finish
ed_requests
()
new
,
abort
ed
=
tracker
.
get_new_and_
abort
ed_requests
()
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
len
(
finished
)
==
1
assert
not
aborted
assert
"2"
in
finished
assert
len
(
new
)
==
1
assert
len
(
new
)
==
1
assert
new
[
0
][
"request_id"
]
==
"5"
assert
new
[
0
][
"request_id"
]
==
"5"
assert
stream_2
.
finished
assert
stream_2
.
finished
...
...
tests/test_utils.py
View file @
9a3f49ae
...
@@ -2,6 +2,7 @@ import asyncio
...
@@ -2,6 +2,7 @@ import asyncio
import
os
import
os
import
socket
import
socket
import
sys
import
sys
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncIterator
,
Awaitable
,
Protocol
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncIterator
,
Awaitable
,
Protocol
,
Tuple
,
TypeVar
)
Tuple
,
TypeVar
)
...
@@ -37,11 +38,11 @@ async def test_merge_async_iterators():
...
@@ -37,11 +38,11 @@ async def test_merge_async_iterators():
yield
f
"item from iterator
{
idx
}
"
yield
f
"item from iterator
{
idx
}
"
await
asyncio
.
sleep
(
0.1
)
await
asyncio
.
sleep
(
0.1
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
p
ass
p
rint
(
f
"iterator
{
idx
}
cancelled"
)
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
merged_iterator
:
AsyncIterator
[
Tuple
[
int
,
str
]]
=
merge_async_iterators
(
merged_iterator
:
AsyncIterator
[
Tuple
[
int
,
str
]]
=
merge_async_iterators
(
*
iterators
)
*
iterators
,
is_cancelled
=
partial
(
asyncio
.
sleep
,
0
,
result
=
False
)
)
async
def
stream_output
(
generator
:
AsyncIterator
[
Tuple
[
int
,
str
]]):
async
def
stream_output
(
generator
:
AsyncIterator
[
Tuple
[
int
,
str
]]):
async
for
idx
,
output
in
generator
:
async
for
idx
,
output
in
generator
:
...
...
vllm/engine/async_llm_engine.py
View file @
9a3f49ae
import
asyncio
import
asyncio
import
time
import
time
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Async
It
erator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
from
typing
import
(
Async
Gen
erator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
...
@@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task,
...
@@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task,
"actual cause."
)
from
e
"actual cause."
)
from
e
STOP_ITERATION
=
Exception
()
# Sentinel
class
AsyncStream
:
class
AsyncStream
:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously."""
that can be iterated over asynchronously
via an async generator
."""
def
__init__
(
self
,
request_id
:
str
)
->
None
:
def
__init__
(
self
,
request_id
:
str
,
cancel
:
Callable
[[
str
],
None
]
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
_cancel
=
cancel
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
self
.
_finished
=
False
...
@@ -77,22 +81,30 @@ class AsyncStream:
...
@@ -77,22 +81,30 @@ class AsyncStream:
return
return
self
.
_queue
.
put_nowait
(
item
)
self
.
_queue
.
put_nowait
(
item
)
def
finish
(
self
)
->
None
:
def
finish
(
self
,
cancelled
:
bool
=
False
)
->
None
:
self
.
_queue
.
put_nowait
(
StopAsyncIteration
())
if
not
self
.
_finished
:
self
.
_finished
=
True
self
.
_finished
=
True
self
.
_queue
.
put_nowait
(
asyncio
.
CancelledError
if
cancelled
else
STOP_ITERATION
)
@
property
@
property
def
finished
(
self
)
->
bool
:
def
finished
(
self
)
->
bool
:
return
self
.
_finished
return
self
.
_finished
def
__aiter__
(
self
):
async
def
generator
(
return
self
self
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
async
def
__anext__
(
self
)
->
Union
[
RequestOutput
,
EmbeddingRequestOutput
]:
try
:
while
not
self
.
_finished
:
result
=
await
self
.
_queue
.
get
()
result
=
await
self
.
_queue
.
get
()
if
isinstance
(
result
,
Exception
):
if
isinstance
(
result
,
Exception
):
if
result
==
STOP_ITERATION
:
return
raise
result
raise
result
return
result
yield
result
except
GeneratorExit
:
self
.
_cancel
(
self
.
request_id
)
raise
asyncio
.
CancelledError
from
None
class
RequestTracker
:
class
RequestTracker
:
...
@@ -100,7 +112,7 @@ class RequestTracker:
...
@@ -100,7 +112,7 @@ class RequestTracker:
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
_request_streams
:
Dict
[
str
,
AsyncStream
]
=
{}
self
.
_request_streams
:
Dict
[
str
,
AsyncStream
]
=
{}
self
.
_
finish
ed_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_
abort
ed_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
dict
]]
=
asyncio
.
Queue
()
dict
]]
=
asyncio
.
Queue
()
self
.
new_requests_event
=
asyncio
.
Event
()
self
.
new_requests_event
=
asyncio
.
Event
()
...
@@ -131,15 +143,21 @@ class RequestTracker:
...
@@ -131,15 +143,21 @@ class RequestTracker:
verbose
:
bool
=
False
)
->
None
:
verbose
:
bool
=
False
)
->
None
:
"""Process a request output from the engine."""
"""Process a request output from the engine."""
request_id
=
request_output
.
request_id
request_id
=
request_output
.
request_id
finished
=
request_output
.
finished
if
finished
:
stream
=
self
.
_request_streams
.
pop
(
request_id
,
None
)
else
:
stream
=
self
.
_request_streams
.
get
(
request_id
)
# Guard against a KeyError which can occur if the request was aborted
# Guard against a KeyError which can occur if the request was aborted
# while the output was generated
# while the output was generated
if
(
stream
:
=
self
.
_request_streams
.
get
(
request_id
))
is
not
None
:
if
stream
is
not
None
:
stream
.
put
(
request_output
)
stream
.
put
(
request_output
)
if
request_output
.
finished
:
if
finished
:
if
verbose
:
stream
.
finish
()
if
verbose
and
finished
:
logger
.
info
(
"Finished request %s."
,
request_id
)
logger
.
info
(
"Finished request %s."
,
request_id
)
self
.
abort_request
(
request_id
)
def
process_exception
(
self
,
def
process_exception
(
self
,
request_id
:
str
,
request_id
:
str
,
...
@@ -162,7 +180,8 @@ class RequestTracker:
...
@@ -162,7 +180,8 @@ class RequestTracker:
if
request_id
in
self
.
_request_streams
:
if
request_id
in
self
.
_request_streams
:
raise
KeyError
(
f
"Request
{
request_id
}
already exists."
)
raise
KeyError
(
f
"Request
{
request_id
}
already exists."
)
stream
=
AsyncStream
(
request_id
)
abort_request
=
partial
(
self
.
abort_request
,
verbose
=
verbose
)
stream
=
AsyncStream
(
request_id
,
abort_request
)
self
.
_new_requests
.
put_nowait
((
stream
,
{
self
.
_new_requests
.
put_nowait
((
stream
,
{
"request_id"
:
request_id
,
"request_id"
:
request_id
,
**
engine_add_request_kwargs
**
engine_add_request_kwargs
...
@@ -175,36 +194,36 @@ class RequestTracker:
...
@@ -175,36 +194,36 @@ class RequestTracker:
return
stream
return
stream
def
abort_request
(
self
,
request_id
:
str
,
*
,
verbose
:
bool
=
False
)
->
None
:
def
abort_request
(
self
,
request_id
:
str
,
*
,
cancelled
:
bool
=
False
,
verbose
:
bool
=
False
)
->
None
:
"""Abort a request during next background loop iteration."""
"""Abort a request during next background loop iteration."""
if
verbose
:
if
verbose
:
logger
.
info
(
"Aborted request %s."
,
request_id
)
logger
.
info
(
"Aborted request %s."
,
request_id
)
self
.
_
finish
ed_requests
.
put_nowait
(
request_id
)
self
.
_
abort
ed_requests
.
put_nowait
(
request_id
)
if
request_id
not
in
self
.
_request_streams
or
self
.
_request_streams
[
stream
=
self
.
_request_streams
.
pop
(
request_id
,
None
)
request_id
].
finished
:
if
stream
is
not
None
:
# The request has already finished or been aborted.
stream
.
finish
(
cancelled
=
cancelled
)
return
self
.
_request_streams
[
request_id
].
finish
()
def
get_new_and_aborted_requests
(
self
)
->
Tuple
[
List
[
Dict
],
Set
[
str
]]:
def
get_new_and_finished_requests
(
self
)
->
Tuple
[
List
[
Dict
],
Set
[
str
]]:
"""Get the new requests and finished requests to be
"""Get the new requests and finished requests to be
sent to the engine."""
sent to the engine."""
new_requests
:
List
[
Dict
]
=
[]
new_requests
:
List
[
Dict
]
=
[]
finished_requests
:
Set
[
str
]
=
set
()
finished_requests
:
Set
[
str
]
=
set
()
while
not
self
.
_
finish
ed_requests
.
empty
():
while
not
self
.
_
abort
ed_requests
.
empty
():
request_id
=
self
.
_
finish
ed_requests
.
get_nowait
()
request_id
=
self
.
_
abort
ed_requests
.
get_nowait
()
finished_requests
.
add
(
request_id
)
finished_requests
.
add
(
request_id
)
self
.
_request_streams
.
pop
(
request_id
,
None
)
while
not
self
.
_new_requests
.
empty
():
while
not
self
.
_new_requests
.
empty
():
stream
,
new_request
=
self
.
_new_requests
.
get_nowait
()
stream
,
new_request
=
self
.
_new_requests
.
get_nowait
()
if
stream
.
request_id
in
finished_requests
:
if
stream
.
request_id
in
finished_requests
:
# The request has already been aborted.
# The request has already been aborted.
stream
.
finish
()
stream
.
finish
(
cancelled
=
True
)
continue
continue
self
.
_request_streams
[
stream
.
request_id
]
=
stream
self
.
_request_streams
[
stream
.
request_id
]
=
stream
new_requests
.
append
(
new_request
)
new_requests
.
append
(
new_request
)
...
@@ -556,8 +575,8 @@ class AsyncLLMEngine:
...
@@ -556,8 +575,8 @@ class AsyncLLMEngine:
Returns True if there are in-progress requests."""
Returns True if there are in-progress requests."""
new_requests
,
finish
ed_requests
=
(
new_requests
,
abort
ed_requests
=
(
self
.
_request_tracker
.
get_new_and_
finish
ed_requests
())
self
.
_request_tracker
.
get_new_and_
abort
ed_requests
())
for
new_request
in
new_requests
:
for
new_request
in
new_requests
:
# Add the request into the vLLM engine's waiting queue.
# Add the request into the vLLM engine's waiting queue.
...
@@ -576,8 +595,8 @@ class AsyncLLMEngine:
...
@@ -576,8 +595,8 @@ class AsyncLLMEngine:
verbose
=
self
.
log_requests
,
verbose
=
self
.
log_requests
,
)
)
if
finish
ed_requests
:
if
abort
ed_requests
:
await
self
.
_engine_abort
(
finish
ed_requests
)
await
self
.
_engine_abort
(
abort
ed_requests
)
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
request_outputs
=
await
self
.
engine
.
step
.
remote
()
# type: ignore
request_outputs
=
await
self
.
engine
.
step
.
remote
()
# type: ignore
...
@@ -666,6 +685,8 @@ class AsyncLLMEngine:
...
@@ -666,6 +685,8 @@ class AsyncLLMEngine:
raise
raise
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
# This method does not need to be async, but kept that way
# for backwards compatibility.
async
def
add_request
(
async
def
add_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
...
@@ -675,7 +696,7 @@ class AsyncLLMEngine:
...
@@ -675,7 +696,7 @@ class AsyncLLMEngine:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Async
Stream
:
)
->
Async
Generator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]
:
if
not
self
.
is_running
:
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
if
self
.
start_engine_loop
:
self
.
start_background_loop
()
self
.
start_background_loop
()
...
@@ -686,20 +707,17 @@ class AsyncLLMEngine:
...
@@ -686,20 +707,17 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"error that caused the background loop to stop "
"(AsyncEngineDeadError)."
)
"(AsyncEngineDeadError)."
)
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
stream
=
self
.
_request_tracker
.
add_request
(
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
request_id
,
verbose
=
self
.
log_requests
,
verbose
=
self
.
log_requests
,
inputs
=
inputs
,
inputs
=
inputs
,
params
=
params
,
params
=
params
,
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
or
time
.
time
()
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
)
return
stream
return
stream
.
generator
()
async
def
generate
(
async
def
generate
(
self
,
self
,
...
@@ -709,7 +727,7 @@ class AsyncLLMEngine:
...
@@ -709,7 +727,7 @@ class AsyncLLMEngine:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Async
It
erator
[
RequestOutput
]:
)
->
Async
Gen
erator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
Generate outputs for a request. This method is a coroutine. It adds the
...
@@ -774,7 +792,7 @@ class AsyncLLMEngine:
...
@@ -774,7 +792,7 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> # Process and return the final output
>>> ...
>>> ...
"""
"""
async
for
output
in
self
.
_process
_request
(
async
for
output
in
await
self
.
add
_request
(
request_id
,
request_id
,
inputs
,
inputs
,
sampling_params
,
sampling_params
,
...
@@ -791,7 +809,7 @@ class AsyncLLMEngine:
...
@@ -791,7 +809,7 @@ class AsyncLLMEngine:
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
Async
It
erator
[
EmbeddingRequestOutput
]:
)
->
Async
Gen
erator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
Generate outputs for a request. This method is a coroutine. It adds the
...
@@ -852,7 +870,7 @@ class AsyncLLMEngine:
...
@@ -852,7 +870,7 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> # Process and return the final output
>>> ...
>>> ...
"""
"""
async
for
output
in
self
.
_process
_request
(
async
for
output
in
await
self
.
add
_request
(
request_id
,
request_id
,
inputs
,
inputs
,
pooling_params
,
pooling_params
,
...
@@ -861,37 +879,6 @@ class AsyncLLMEngine:
...
@@ -861,37 +879,6 @@ class AsyncLLMEngine:
):
):
yield
LLMEngine
.
validate_output
(
output
,
EmbeddingRequestOutput
)
yield
LLMEngine
.
validate_output
(
output
,
EmbeddingRequestOutput
)
async
def
_process_request
(
self
,
request_id
:
str
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
*
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
AsyncIterator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time
=
time
.
time
()
stream
=
await
self
.
add_request
(
request_id
,
inputs
,
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
)
try
:
async
for
request_output
in
stream
:
yield
request_output
except
(
Exception
,
asyncio
.
CancelledError
)
as
e
:
self
.
_abort
(
request_id
)
raise
e
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
"""Abort a request.
...
@@ -920,6 +907,7 @@ class AsyncLLMEngine:
...
@@ -920,6 +907,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
request_id: The unique id of the request.
"""
"""
self
.
_request_tracker
.
abort_request
(
request_id
,
self
.
_request_tracker
.
abort_request
(
request_id
,
cancelled
=
True
,
verbose
=
self
.
log_requests
)
verbose
=
self
.
log_requests
)
async
def
get_model_config
(
self
)
->
ModelConfig
:
async
def
get_model_config
(
self
)
->
ModelConfig
:
...
...
vllm/engine/protocol.py
View file @
9a3f49ae
from
typing
import
(
Async
It
erator
,
List
,
Mapping
,
Optional
,
Protocol
,
from
typing
import
(
Async
Gen
erator
,
List
,
Mapping
,
Optional
,
Protocol
,
runtime_checkable
)
runtime_checkable
)
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
...
@@ -30,7 +30,7 @@ class AsyncEngineClient(Protocol):
...
@@ -30,7 +30,7 @@ class AsyncEngineClient(Protocol):
def
errored
(
self
)
->
bool
:
def
errored
(
self
)
->
bool
:
...
...
async
def
generate
(
def
generate
(
self
,
self
,
inputs
:
PromptInputs
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
...
@@ -38,17 +38,17 @@ class AsyncEngineClient(Protocol):
...
@@ -38,17 +38,17 @@ class AsyncEngineClient(Protocol):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Async
It
erator
[
RequestOutput
]:
)
->
Async
Gen
erator
[
RequestOutput
,
None
]:
"""Generates outputs for a request"""
"""Generates outputs for a request"""
async
def
encode
(
def
encode
(
self
,
self
,
inputs
:
PromptInputs
,
inputs
:
PromptInputs
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
Async
It
erator
[
EmbeddingRequestOutput
]:
)
->
Async
Gen
erator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model."""
"""Generate outputs for a request from an embedding model."""
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
...
...
vllm/entrypoints/api_server.py
View file @
9a3f49ae
...
@@ -20,7 +20,8 @@ from vllm.entrypoints.launcher import serve_http
...
@@ -20,7 +20,8 @@ from vllm.entrypoints.launcher import serve_http
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.utils
import
(
FlexibleArgumentParser
,
iterate_with_cancellation
,
random_uuid
)
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
...
@@ -53,6 +54,8 @@ async def generate(request: Request) -> Response:
...
@@ -53,6 +54,8 @@ async def generate(request: Request) -> Response:
assert
engine
is
not
None
assert
engine
is
not
None
results_generator
=
engine
.
generate
(
prompt
,
sampling_params
,
request_id
)
results_generator
=
engine
.
generate
(
prompt
,
sampling_params
,
request_id
)
results_generator
=
iterate_with_cancellation
(
results_generator
,
is_cancelled
=
request
.
is_disconnected
)
# Streaming case
# Streaming case
async
def
stream_results
()
->
AsyncGenerator
[
bytes
,
None
]:
async
def
stream_results
()
->
AsyncGenerator
[
bytes
,
None
]:
...
@@ -69,12 +72,11 @@ async def generate(request: Request) -> Response:
...
@@ -69,12 +72,11 @@ async def generate(request: Request) -> Response:
# Non-streaming case
# Non-streaming case
final_output
=
None
final_output
=
None
try
:
async
for
request_output
in
results_generator
:
async
for
request_output
in
results_generator
:
if
await
request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
engine
.
abort
(
request_id
)
return
Response
(
status_code
=
499
)
final_output
=
request_output
final_output
=
request_output
except
asyncio
.
CancelledError
:
return
Response
(
status_code
=
499
)
assert
final_output
is
not
None
assert
final_output
is
not
None
prompt
=
final_output
.
prompt
prompt
=
final_output
.
prompt
...
...
vllm/entrypoints/openai/rpc/client.py
View file @
9a3f49ae
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
Async
It
erator
,
Mapping
,
Optional
from
typing
import
Any
,
Async
Gen
erator
,
Mapping
,
Optional
import
cloudpickle
import
cloudpickle
import
zmq
import
zmq
...
@@ -190,9 +190,11 @@ class AsyncEngineRPCClient:
...
@@ -190,9 +190,11 @@ class AsyncEngineRPCClient:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Async
It
erator
[
RequestOutput
]:
)
->
Async
Gen
erator
[
RequestOutput
,
None
]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
finished
=
False
try
:
with
self
.
socket
()
as
socket
:
with
self
.
socket
()
as
socket
:
# Send RPCGenerateRequest to the RPCServer.
# Send RPCGenerateRequest to the RPCServer.
...
@@ -208,18 +210,18 @@ class AsyncEngineRPCClient:
...
@@ -208,18 +210,18 @@ class AsyncEngineRPCClient:
])
])
# Stream back the results from the RPC Server.
# Stream back the results from the RPC Server.
while
True
:
while
not
finished
:
message
=
await
socket
.
recv
()
message
=
await
socket
.
recv
()
request_output
=
cloudpickle
.
loads
(
message
)
request_output
=
cloudpickle
.
loads
(
message
)
if
isinstance
(
request_output
,
Exception
):
if
isinstance
(
request_output
,
Exception
):
raise
request_output
raise
request_output
if
request_output
.
finished
:
finished
=
request_output
.
finished
break
yield
request_output
yield
request_output
yield
request_output
finally
:
if
not
finished
:
await
self
.
abort
(
request_id
)
async
def
check_health
(
self
)
->
None
:
async
def
check_health
(
self
)
->
None
:
"""Raise if unhealthy"""
"""Raise if unhealthy"""
...
@@ -243,6 +245,6 @@ class AsyncEngineRPCClient:
...
@@ -243,6 +245,6 @@ class AsyncEngineRPCClient:
"f{health_message}"
)
"f{health_message}"
)
async
def
encode
(
self
,
*
args
,
async
def
encode
(
self
,
*
args
,
**
kwargs
)
->
Async
It
erator
[
EmbeddingRequestOutput
]:
**
kwargs
)
->
Async
Gen
erator
[
EmbeddingRequestOutput
,
None
]:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Embeddings not supported with multiprocessing backend"
)
"Embeddings not supported with multiprocessing backend"
)
vllm/entrypoints/openai/serving_chat.py
View file @
9a3f49ae
import
asyncio
import
time
import
time
from
typing
import
AsyncGenerator
,
AsyncIterator
,
Dict
,
List
,
Optional
from
typing
import
AsyncGenerator
,
AsyncIterator
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -29,7 +30,7 @@ from vllm.outputs import RequestOutput
...
@@ -29,7 +30,7 @@ from vllm.outputs import RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
log_tracing_disabled_warning
)
from
vllm.utils
import
random_uuid
from
vllm.utils
import
iterate_with_cancellation
,
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -176,15 +177,17 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -176,15 +177,17 @@ class OpenAIServingChat(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
if
raw_request
:
result_generator
=
iterate_with_cancellation
(
result_generator
,
raw_request
.
is_disconnected
)
# Streaming response
# Streaming response
if
request
.
stream
:
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
return
self
.
chat_completion_stream_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
else
:
try
:
try
:
return
await
self
.
chat_completion_full_generator
(
return
await
self
.
chat_completion_full_generator
(
request
,
raw_request
,
result_generator
,
request_id
,
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
conversation
,
tokenizer
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -422,7 +425,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -422,7 +425,6 @@ class OpenAIServingChat(OpenAIServing):
async
def
chat_completion_full_generator
(
async
def
chat_completion_full_generator
(
self
,
self
,
request
:
ChatCompletionRequest
,
request
:
ChatCompletionRequest
,
raw_request
:
Optional
[
Request
],
result_generator
:
AsyncIterator
[
RequestOutput
],
result_generator
:
AsyncIterator
[
RequestOutput
],
request_id
:
str
,
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
conversation
:
List
[
ConversationMessage
],
...
@@ -433,12 +435,12 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -433,12 +435,12 @@ class OpenAIServingChat(OpenAIServing):
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
final_res
:
Optional
[
RequestOutput
]
=
None
final_res
:
Optional
[
RequestOutput
]
=
None
try
:
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
if
raw_request
is
not
None
and
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
async_engine_client
.
abort
(
request_id
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res
=
res
final_res
=
res
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
assert
final_res
is
not
None
assert
final_res
is
not
None
choices
:
List
[
ChatCompletionResponseChoice
]
=
[]
choices
:
List
[
ChatCompletionResponseChoice
]
=
[]
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
9a3f49ae
import
asyncio
import
time
import
time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Callable
,
Dict
,
List
,
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Callable
,
Dict
,
List
,
Optional
)
Optional
)
...
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
List
[
Async
It
erator
[
RequestOutput
]]
=
[]
generators
:
List
[
Async
Gen
erator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
(
(
lora_request
,
lora_request
,
...
@@ -144,7 +145,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -144,7 +145,8 @@ class OpenAIServingCompletion(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
result_generator
:
AsyncIterator
[
Tuple
[
result_generator
:
AsyncIterator
[
Tuple
[
int
,
RequestOutput
]]
=
merge_async_iterators
(
*
generators
)
int
,
RequestOutput
]]
=
merge_async_iterators
(
*
generators
,
is_cancelled
=
raw_request
.
is_disconnected
)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
# results. In addition, we do not stream the results when use
...
@@ -156,7 +158,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -156,7 +158,6 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
# Streaming response
if
stream
:
if
stream
:
return
self
.
completion_stream_generator
(
request
,
return
self
.
completion_stream_generator
(
request
,
raw_request
,
result_generator
,
result_generator
,
request_id
,
request_id
,
created_time
,
created_time
,
...
@@ -168,10 +169,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -168,10 +169,6 @@ class OpenAIServingCompletion(OpenAIServing):
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
try
:
try
:
async
for
i
,
res
in
result_generator
:
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
final_res_batch
[
i
]
=
res
for
i
,
final_res
in
enumerate
(
final_res_batch
):
for
i
,
final_res
in
enumerate
(
final_res_batch
):
...
@@ -194,6 +191,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -194,6 +191,8 @@ class OpenAIServingCompletion(OpenAIServing):
model_name
,
model_name
,
tokenizer
,
tokenizer
,
)
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -214,7 +213,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -214,7 +213,6 @@ class OpenAIServingCompletion(OpenAIServing):
async
def
completion_stream_generator
(
async
def
completion_stream_generator
(
self
,
self
,
request
:
CompletionRequest
,
request
:
CompletionRequest
,
raw_request
:
Request
,
result_generator
:
AsyncIterator
[
Tuple
[
int
,
RequestOutput
]],
result_generator
:
AsyncIterator
[
Tuple
[
int
,
RequestOutput
]],
request_id
:
str
,
request_id
:
str
,
created_time
:
int
,
created_time
:
int
,
...
@@ -230,12 +228,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -230,12 +228,6 @@ class OpenAIServingCompletion(OpenAIServing):
try
:
try
:
async
for
prompt_idx
,
res
in
result_generator
:
async
for
prompt_idx
,
res
in
result_generator
:
# Abort the request if the client disconnects.
if
await
raw_request
.
is_disconnected
():
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
raise
StopAsyncIteration
()
for
output
in
res
.
outputs
:
for
output
in
res
.
outputs
:
i
=
output
.
index
+
prompt_idx
*
num_choices
i
=
output
.
index
+
prompt_idx
*
num_choices
# TODO(simon): optimize the performance by avoiding full
# TODO(simon): optimize the performance by avoiding full
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
9a3f49ae
import
asyncio
import
base64
import
base64
import
time
import
time
from
typing
import
AsyncIterator
,
List
,
Optional
,
Tuple
,
cast
from
typing
import
AsyncGenerator
,
AsyncIterator
,
List
,
Optional
,
Tuple
,
cast
import
numpy
as
np
import
numpy
as
np
from
fastapi
import
Request
from
fastapi
import
Request
...
@@ -92,7 +93,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -92,7 +93,7 @@ class OpenAIServingEmbedding(OpenAIServing):
created_time
=
int
(
time
.
monotonic
())
created_time
=
int
(
time
.
monotonic
())
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
List
[
Async
It
erator
[
EmbeddingRequestOutput
]]
=
[]
generators
:
List
[
Async
Gen
erator
[
EmbeddingRequestOutput
,
None
]]
=
[]
try
:
try
:
(
(
lora_request
,
lora_request
,
...
@@ -138,17 +139,14 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -138,17 +139,14 @@ class OpenAIServingEmbedding(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
result_generator
:
AsyncIterator
[
Tuple
[
result_generator
:
AsyncIterator
[
Tuple
[
int
,
EmbeddingRequestOutput
]]
=
merge_async_iterators
(
*
generators
)
int
,
EmbeddingRequestOutput
]]
=
merge_async_iterators
(
*
generators
,
is_cancelled
=
raw_request
.
is_disconnected
)
# Non-streaming response
# Non-streaming response
final_res_batch
:
List
[
Optional
[
EmbeddingRequestOutput
]]
final_res_batch
:
List
[
Optional
[
EmbeddingRequestOutput
]]
final_res_batch
=
[
None
]
*
len
(
prompts
)
final_res_batch
=
[
None
]
*
len
(
prompts
)
try
:
try
:
async
for
i
,
res
in
result_generator
:
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
final_res_batch
[
i
]
=
res
for
final_res
in
final_res_batch
:
for
final_res
in
final_res_batch
:
...
@@ -160,6 +158,8 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -160,6 +158,8 @@ class OpenAIServingEmbedding(OpenAIServing):
response
=
request_output_to_embedding_response
(
response
=
request_output_to_embedding_response
(
final_res_batch_checked
,
request_id
,
created_time
,
model_name
,
final_res_batch_checked
,
request_id
,
created_time
,
model_name
,
encoding_format
)
encoding_format
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
...
vllm/utils.py
View file @
9a3f49ae
import
argparse
import
argparse
import
asyncio
import
asyncio
import
contextlib
import
datetime
import
datetime
import
enum
import
enum
import
gc
import
gc
...
@@ -11,10 +12,11 @@ import tempfile
...
@@ -11,10 +12,11 @@ import tempfile
import
threading
import
threading
import
uuid
import
uuid
import
warnings
import
warnings
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
from
typing
import
(
Any
,
Async
It
erator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
from
typing
import
(
Any
,
Async
Gen
erator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Hashable
,
List
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
TypeVar
,
Hashable
,
List
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
TypeVar
,
Union
,
overload
)
Union
,
overload
)
...
@@ -373,63 +375,74 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
...
@@ -373,63 +375,74 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return
_async_wrapper
return
_async_wrapper
class
ProducerFinished
:
async
def
iterate_with_cancellation
(
pass
iterator
:
AsyncGenerator
[
T
,
None
],
is_cancelled
:
Callable
[[],
Awaitable
[
bool
]],
)
->
AsyncGenerator
[
T
,
None
]:
"""Convert async iterator into one that polls the provided function
at least once per second to check for client cancellation.
"""
# Can use anext() in python >= 3.10
awaits
=
[
ensure_future
(
iterator
.
__anext__
())]
while
True
:
done
,
pending
=
await
asyncio
.
wait
(
awaits
,
timeout
=
1
)
if
await
is_cancelled
():
with
contextlib
.
suppress
(
BaseException
):
awaits
[
0
].
cancel
()
await
iterator
.
aclose
()
raise
asyncio
.
CancelledError
(
"client cancelled"
)
if
done
:
try
:
item
=
await
awaits
[
0
]
awaits
[
0
]
=
ensure_future
(
iterator
.
__anext__
())
yield
item
except
StopAsyncIteration
:
# we are done
return
def
merge_async_iterators
(
async
def
merge_async_iterators
(
*
iterators
:
AsyncIterator
[
T
])
->
AsyncIterator
[
Tuple
[
int
,
T
]]:
*
iterators
:
AsyncGenerator
[
T
,
None
],
is_cancelled
:
Callable
[[],
Awaitable
[
bool
]],
)
->
AsyncGenerator
[
Tuple
[
int
,
T
],
None
]:
"""Merge multiple asynchronous iterators into a single iterator.
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
iterator that yields the item.
"""
queue
:
asyncio
.
Queue
[
Union
[
Tuple
[
int
,
T
],
ProducerFinished
,
Exception
]]
=
asyncio
.
Queue
()
producers
=
len
(
iterators
)
It also polls the provided function at least once per second to check
for client cancellation.
"""
async
def
producer
(
i
:
int
,
iterator
:
AsyncIterator
[
T
]):
# Can use anext() in python >= 3.10
awaits
=
{
ensure_future
(
pair
[
1
].
__anext__
()):
pair
for
pair
in
enumerate
(
iterators
)
}
try
:
try
:
async
for
item
in
iterator
:
while
awaits
:
await
queue
.
put
((
i
,
item
))
done
,
pending
=
await
asyncio
.
wait
(
awaits
.
keys
(),
except
Exception
as
e
:
return_when
=
FIRST_COMPLETED
,
await
queue
.
put
(
e
)
timeout
=
1
)
# Signal to the consumer that we've finished
if
await
is_cancelled
():
await
queue
.
put
(
ProducerFinished
())
raise
asyncio
.
CancelledError
(
"client cancelled"
)
for
d
in
done
:
_tasks
=
[
pair
=
awaits
.
pop
(
d
)
asyncio
.
create_task
(
producer
(
i
,
iterator
))
for
i
,
iterator
in
enumerate
(
iterators
)
]
async
def
consumer
():
remaining
=
producers
try
:
try
:
while
remaining
or
not
queue
.
empty
():
item
=
await
d
# we think there is a race condition here
i
,
it
=
pair
item
=
await
queue
.
get
()
awaits
[
ensure_future
(
it
.
__anext__
())]
=
pair
yield
i
,
item
if
isinstance
(
item
,
ProducerFinished
):
except
StopAsyncIteration
:
# Signal that a producer finished- not a real item
pass
remaining
-=
1
finally
:
continue
# Cancel any remaining iterators
for
f
,
(
_
,
it
)
in
awaits
.
items
():
if
isinstance
(
item
,
Exception
):
with
contextlib
.
suppress
(
BaseException
):
raise
item
f
.
cancel
()
yield
item
await
it
.
aclose
()
except
(
Exception
,
asyncio
.
CancelledError
)
as
e
:
for
task
in
_tasks
:
if
sys
.
version_info
>=
(
3
,
9
):
# msg parameter only supported in Python 3.9+
task
.
cancel
(
e
)
else
:
task
.
cancel
()
raise
e
await
asyncio
.
gather
(
*
_tasks
)
return
consumer
()
def
get_ip
()
->
str
:
def
get_ip
()
->
str
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment