Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1b875a0e
Unverified
Commit
1b875a0e
authored
Dec 27, 2024
by
Robert Shaw
Committed by
GitHub
Dec 26, 2024
Browse files
[V1][3/N] API Server: Reduce Task Switching + Handle Abort Properly (#11534)
parent
eb881ed0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
153 deletions
+63
-153
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+62
-97
vllm/v1/engine/async_stream.py
vllm/v1/engine/async_stream.py
+0
-55
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+1
-1
No files found.
vllm/v1/engine/async_llm.py
View file @
1b875a0e
...
@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
...
@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine.async_stream
import
AsyncStream
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.detokenizer
import
Detokenizer
from
vllm.v1.engine.detokenizer
import
Detokenizer
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.engine.processor
import
Processor
...
@@ -54,10 +53,8 @@ class AsyncLLM(EngineClient):
...
@@ -54,10 +53,8 @@ class AsyncLLM(EngineClient):
lora_config
=
vllm_config
.
lora_config
)
lora_config
=
vllm_config
.
lora_config
)
self
.
tokenizer
.
ping
()
self
.
tokenizer
.
ping
()
# Request streams (map of request_id -> AsyncStream).
# Request streams (map of request_id -> queue).
self
.
request_streams
:
Dict
[
str
,
AsyncStream
]
=
{}
self
.
rid_to_queue
:
Dict
[
str
,
asyncio
.
Queue
]
=
{}
# List of cancelled request ids to be aborted.
self
.
client_aborted_requests
:
List
[
str
]
=
[]
# Processor (converts Inputs --> EngineCoreRequests).
# Processor (converts Inputs --> EngineCoreRequests).
self
.
processor
=
Processor
(
self
.
processor
=
Processor
(
...
@@ -153,14 +150,13 @@ class AsyncLLM(EngineClient):
...
@@ -153,14 +150,13 @@ class AsyncLLM(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
A
sync
Generator
[
Union
[
RequestOutput
,
Pooling
RequestOutput
]
,
None
]
:
)
->
a
sync
io
.
Queue
[
RequestOutput
]:
"""Add new request to the AsyncLLM."""
"""Add new request to the AsyncLLM."""
if
self
.
detokenizer
.
is_request_active
(
request_id
):
# 1) Create a new output queue for the request.
raise
ValueError
(
f
"Request
{
request_id
}
already exists."
)
if
request_id
in
self
.
rid_to_queue
:
raise
ValueError
(
f
"Request id
{
request_id
}
already running."
)
# 1) Create a new AsyncStream for the request.
self
.
rid_to_queue
[
request_id
]
=
asyncio
.
Queue
()
stream
=
self
.
_add_request_to_streams
(
request_id
)
# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req
,
engine_core_req
=
self
.
processor
.
process_inputs
(
detokenizer_req
,
engine_core_req
=
self
.
processor
.
process_inputs
(
...
@@ -173,8 +169,10 @@ class AsyncLLM(EngineClient):
...
@@ -173,8 +169,10 @@ class AsyncLLM(EngineClient):
# 4) Add the EngineCoreRequest to EngineCore (separate process).
# 4) Add the EngineCoreRequest to EngineCore (separate process).
await
self
.
engine_core
.
add_request_async
(
engine_core_req
)
await
self
.
engine_core
.
add_request_async
(
engine_core_req
)
# 5) Return the generator.
if
self
.
log_requests
:
return
stream
.
generator
()
logger
.
info
(
"Added request %s."
,
request_id
)
return
self
.
rid_to_queue
[
request_id
]
# TODO: we should support multiple prompts in one call, as you
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# can do with LLM.generate. So that for multi-prompt completion
...
@@ -194,7 +192,7 @@ class AsyncLLM(EngineClient):
...
@@ -194,7 +192,7 @@ class AsyncLLM(EngineClient):
"""
"""
Main function called by the API server to kick off a request
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 1) Making an output queue corresponding to the Request.
#
2) Processing the Input.
*
2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
* 4) Adding the Request to the EngineCore (separate process).
...
@@ -206,14 +204,15 @@ class AsyncLLM(EngineClient):
...
@@ -206,14 +204,15 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller.
returning the RequestOutput back to the caller.
"""
"""
# We start the output_handler on the first call to generate() so that
try
:
# we can call __init__ before the event loop starts, which enables us
# We start the output_handler on the first call to generate() so
# to handle startup failure gracefully in the OpenAI server.
# we can call __init__ before the event loop, which enables us
if
self
.
output_handler
is
None
:
# to handle startup failure gracefully in the OpenAI server.
self
.
output_handler
=
asyncio
.
create_task
(
if
self
.
output_handler
is
None
:
self
.
_run_output_handler
())
self
.
output_handler
=
asyncio
.
create_task
(
self
.
_run_output_handler
())
async
for
output
in
await
self
.
add_request
(
q
=
await
self
.
add_request
(
request_id
,
request_id
,
prompt
,
prompt
,
sampling_params
,
sampling_params
,
...
@@ -221,79 +220,42 @@ class AsyncLLM(EngineClient):
...
@@ -221,79 +220,42 @@ class AsyncLLM(EngineClient):
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
priority
=
priority
,
):
)
yield
output
def
_finish_stream
(
self
,
request_id
:
str
):
stream
=
self
.
request_streams
.
pop
(
request_id
,
None
)
if
stream
is
not
None
:
stream
.
finish
()
def
_add_request_to_streams
(
self
,
request_id
:
str
,
)
->
AsyncStream
:
if
request_id
in
self
.
request_streams
:
raise
ValueError
(
f
"Request id
{
request_id
}
already running."
)
# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs
=
self
.
client_aborted_requests
stream
=
AsyncStream
(
request_id
,
aborted_reqs
.
append
)
self
.
request_streams
[
request_id
]
=
stream
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request_id
)
return
stream
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
async
def
_process_cancellations
(
self
)
->
None
:
while
True
:
"""
# Note: drain queue without await if possible (avoids
Process requests cancelled from user disconnecting.
# task switching under load which helps performance).
out
=
q
.
get_nowait
()
if
q
.
qsize
()
>
0
else
await
q
.
get
()
When a client disconnects, AsyncStream._cancel() is called.
We passed a callback to AsyncStream(), which appends to
# Note: both Detokenizer and EngineCore handle their
self.client_aborted_requests.
# own request cleanup based on finished.
if
out
.
finished
:
As a result, if any requests are canceled from the user side
del
self
.
rid_to_queue
[
request_id
]
the request_id will show up in self.client_aborted_requests.
yield
out
"""
break
# Avoid streams having circular ref to parent AsyncLLM object.
yield
out
if
not
self
.
client_aborted_requests
:
return
# If the request is disconnected by the client, the
reqs_to_abort
=
self
.
client_aborted_requests
.
copy
()
# generate() task will be canceled. So, we abort the
self
.
client_aborted_requests
.
clear
()
# request if we end up here.
except
asyncio
.
CancelledError
:
# Remove from Detokenizer.
await
self
.
abort
(
request_id
)
self
.
detokenizer
.
abort_requests
(
reqs_to_abort
)
raise
# Remove from RequestStreams.
for
request_id
in
reqs_to_abort
:
if
self
.
log_requests
:
logger
.
info
(
"User-cancelled request %s."
,
request_id
)
self
.
_finish_stream
(
request_id
)
# Remove from EngineCore.
await
self
.
engine_core
.
abort_requests_async
(
reqs_to_abort
)
def
_process_request_outputs
(
self
,
request_outputs
:
List
[
RequestOutput
]):
def
_process_request_outputs
(
self
,
request_outputs
:
List
[
RequestOutput
]):
"""Process outputs by putting them into per-request
AsyncStream
s."""
"""Process outputs by putting them into per-request
queue
s."""
for
request_output
in
request_outputs
:
for
request_output
in
request_outputs
:
request_id
=
request_output
.
request_id
request_id
=
request_output
.
request_id
assert
request_id
in
self
.
request_streams
# Each request in the API server pulls from the per-request stream.
# Note: it is possible a request was aborted and removed from
stream
=
self
.
request_streams
.
get
(
request_id
)
# the state due to client cancellations, so if we encounter a
if
stream
is
not
None
:
# request id not in the state, we skip.
stream
.
put
(
request_output
)
if
request_id
in
self
.
rid_to_queue
:
self
.
rid_to_queue
[
request_id
].
put_nowait
(
request_output
)
# If finished, remove from the tracker.
if
request_output
.
finished
:
if
self
.
log_requests
:
logger
.
info
(
"Finished request %s."
,
request_id
)
self
.
_finish_stream
(
request_id
)
async
def
_run_output_handler
(
self
):
async
def
_run_output_handler
(
self
):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
...
@@ -306,24 +268,27 @@ class AsyncLLM(EngineClient):
...
@@ -306,24 +268,27 @@ class AsyncLLM(EngineClient):
# 2) Detokenize based on the output.
# 2) Detokenize based on the output.
request_outputs
,
reqs_to_abort
=
self
.
detokenizer
.
step
(
outputs
)
request_outputs
,
reqs_to_abort
=
self
.
detokenizer
.
step
(
outputs
)
# 3) Put the RequestOutputs into the per-request
AsyncStream
s.
# 3) Put the RequestOutputs into the per-request
queue
s.
self
.
_process_request_outputs
(
request_outputs
)
self
.
_process_request_outputs
(
request_outputs
)
# 4) Abort any requests that finished due to stop strings.
# 4) Abort any requests that finished due to stop strings.
await
self
.
engine_core
.
abort_requests_async
(
reqs_to_abort
)
await
self
.
engine_core
.
abort_requests_async
(
reqs_to_abort
)
# 5) Abort any requests due to client cancellations.
await
self
.
_process_cancellations
()
except
BaseException
as
e
:
except
BaseException
as
e
:
logger
.
error
(
e
)
logger
.
error
(
e
)
raise
e
raise
e
# TODO: can we eliminate these?
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
# Note: Who calls this? I don't think this is actually used.
"""Abort RequestId in self, detokenizer, and engine core."""
raise
ValueError
(
"Not Supported on V1 yet."
)
request_ids
=
[
request_id
]
await
self
.
engine_core
.
abort_requests_async
(
request_ids
)
self
.
detokenizer
.
abort_requests
(
request_ids
)
# If a request finishes while we await then the request_id
# will be removed from the tracked queues before we get here.
if
request_id
in
self
.
rid_to_queue
:
del
self
.
rid_to_queue
[
request_id
]
def
encode
(
def
encode
(
self
,
self
,
...
...
vllm/v1/engine/async_stream.py
deleted
100644 → 0
View file @
eb881ed0
import
asyncio
from
typing
import
Any
,
AsyncGenerator
,
Callable
,
Optional
,
Type
,
Union
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Items are buffered in an ``asyncio.Queue`` by ``put()`` and drained by
    ``generator()``. The stream ends when ``finish()`` queues either the
    ``STOP_ITERATION`` sentinel (normal end) or an exception (raised to the
    consumer).
    """

    # Queued by finish() to mark a normal end of stream. It is only ever
    # compared by identity, never by equality.
    STOP_ITERATION = Exception()  # Sentinel

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """Create an empty stream.

        Args:
            request_id: Id passed back to ``cancel`` when the consumer
                stops iterating before the stream finished.
            cancel: Callback invoked with ``request_id`` on early exit.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(
        self,
        item: Union["RequestOutput", "PoolingRequestOutput", Exception],
    ) -> None:
        """Queue ``item`` for the consumer; ignored once finished.

        An exception instance queued here is raised by ``generator()``
        rather than yielded.
        """
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished and queue its terminal item.

        Only the first call has an effect. If ``exception`` is an exception
        instance or class it is queued (and later raised to the consumer);
        otherwise the ``STOP_ITERATION`` sentinel is queued.
        """
        if self._finished:
            return
        self._finished = True
        terminal = (exception if self._is_raisable(exception) else
                    AsyncStream.STOP_ITERATION)
        self._queue.put_nowait(terminal)

    async def generator(
        self
    ) -> AsyncGenerator[Union["RequestOutput", "PoolingRequestOutput"], None]:
        """Yield queued items until the stream terminates.

        Returns normally when the ``STOP_ITERATION`` sentinel is dequeued
        and raises any other dequeued exception. If the generator exits
        before a terminal item was seen (for example because the consumer
        closed it), the ``cancel`` callback is invoked with the request id.
        """
        finished = False
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    finished = True
                    # Identity check: `==` would dispatch to the queued
                    # exception's own __eq__, which could make a real
                    # error look like the sentinel and swallow it.
                    if result is AsyncStream.STOP_ITERATION:
                        return
                    raise result
                yield result
        finally:
            self._finished = True
            if not finished:
                self._cancel(self.request_id)

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if ``value`` is an exception instance or class."""
        if isinstance(value, BaseException):
            return True
        return isinstance(value, type) and issubclass(value, BaseException)
vllm/v1/engine/core.py
View file @
1b875a0e
...
@@ -32,7 +32,7 @@ logger = init_logger(__name__)
...
@@ -32,7 +32,7 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS
=
5000
POLLING_TIMEOUT_MS
=
5000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
LOGGING_TIME_S
=
5000
LOGGING_TIME_S
=
POLLING_TIMEOUT_S
class
EngineCore
:
class
EngineCore
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment