Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5886aa49
Unverified
Commit
5886aa49
authored
Dec 30, 2024
by
Robert Shaw
Committed by
GitHub
Dec 30, 2024
Browse files
[V1] [6/N] API Server: Better Shutdown (#11586)
parent
8d9b6721
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
45 deletions
+40
-45
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+12
-32
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+22
-3
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+6
-10
No files found.
vllm/entrypoints/openai/api_server.py
View file @
5886aa49
...
...
@@ -68,7 +68,7 @@ from vllm.entrypoints.utils import with_cancellation
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
is_valid_ipv6_address
,
kill_process_tree
,
set_ulimit
)
is_valid_ipv6_address
,
set_ulimit
)
from
vllm.version
import
__version__
as
VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
...
...
@@ -133,32 +133,21 @@ async def build_async_engine_client_from_engine_args(
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
# AsyncLLMEngine.
if
(
MQLLMEngineClient
.
is_unsupported_config
(
engine_args
)
or
envs
.
VLLM_USE_V1
or
disable_frontend_multiprocessing
):
engine_config
=
engine_args
.
create_engine_config
(
UsageContext
.
OPENAI_API_SERVER
)
uses_ray
=
getattr
(
AsyncLLMEngine
.
_get_executor_cls
(
engine_config
),
"uses_ray"
,
False
)
build_engine
=
partial
(
AsyncLLMEngine
.
from_engine_args
,
engine_client
:
Optional
[
EngineClient
]
=
None
try
:
engine_client
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
=
engine_args
,
engine_config
=
engine_config
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
if
uses_ray
:
# Must run in main thread with ray for its signal handlers to work
engine_client
=
build_engine
()
else
:
engine_client
=
await
asyncio
.
get_running_loop
().
run_in_executor
(
None
,
build_engine
)
yield
engine_client
if
hasattr
(
engine_client
,
"shutdown"
):
finally
:
if
engine_client
and
hasattr
(
engine_client
,
"shutdown"
):
engine_client
.
shutdown
()
return
#
Otherwise, use the multiprocessing Async
LLMEngine.
#
MQ
LLMEngine.
else
:
if
"PROMETHEUS_MULTIPROC_DIR"
not
in
os
.
environ
:
# Make TemporaryDirectory for prometheus multiprocessing
...
...
@@ -737,15 +726,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
# The child processes will send SIGQUIT to this process when
# any error happens. This process then clean up the whole tree.
# TODO(rob): move this into AsyncLLM.__init__ once we remove
# the context manager below.
def
sigquit_handler
(
signum
,
frame
):
kill_process_tree
(
os
.
getpid
())
signal
.
signal
(
signal
.
SIGQUIT
,
sigquit_handler
)
async
with
build_async_engine_client
(
args
)
as
engine_client
:
app
=
build_app
(
args
)
...
...
vllm/v1/engine/async_llm.py
View file @
5886aa49
import
asyncio
import
os
import
signal
from
typing
import
AsyncGenerator
,
Dict
,
List
,
Mapping
,
Optional
,
Type
,
Union
from
vllm.config
import
ModelConfig
,
VllmConfig
...
...
@@ -16,6 +18,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
kill_process_tree
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.detokenizer
import
Detokenizer
from
vllm.v1.engine.processor
import
Processor
...
...
@@ -38,6 +41,22 @@ class AsyncLLM(EngineClient):
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
)
->
None
:
# The child processes will send SIGQUIT when unrecoverable
# errors happen. We kill the process tree here so that the
# stack trace is very evident.
# TODO: rather than killing the main process, we should
# figure out how to raise an AsyncEngineDeadError and
# handle at the API server level so we can return a better
# error code to the clients calling VLLM.
def
sigquit_handler
(
signum
,
frame
):
logger
.
fatal
(
"AsyncLLM got SIGQUIT from worker processes, shutting "
"down. See stack trace above for root cause issue."
)
kill_process_tree
(
os
.
getpid
())
signal
.
signal
(
signal
.
SIGQUIT
,
sigquit_handler
)
assert
start_engine_loop
self
.
log_requests
=
log_requests
...
...
@@ -276,9 +295,9 @@ class AsyncLLM(EngineClient):
# 4) Abort any requests that finished due to stop strings.
await
self
.
engine_core
.
abort_requests_async
(
reqs_to_abort
)
except
Base
Exception
as
e
:
logger
.
e
rror
(
e
)
raise
e
except
Exception
as
e
:
logger
.
e
xception
(
"EngineCore output handler hit an error: %s"
,
e
)
kill_process_tree
(
os
.
getpid
())
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort RequestId in self, detokenizer, and engine core."""
...
...
vllm/v1/engine/core_client.py
View file @
5886aa49
...
...
@@ -6,7 +6,7 @@ import zmq.asyncio
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_open_zmq_ipc_path
from
vllm.utils
import
get_open_zmq_ipc_path
,
make_zmq_socket
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestUnion
)
...
...
@@ -144,17 +144,13 @@ class MPClient(EngineCoreClient):
else
:
self
.
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
# Path for IPC.
# Path
s and sockets
for IPC.
output_path
=
get_open_zmq_ipc_path
()
input_path
=
get_open_zmq_ipc_path
()
# Get output (EngineCoreOutput) from EngineCore.
self
.
output_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PULL
)
self
.
output_socket
.
connect
(
output_path
)
# Send input (EngineCoreRequest) to EngineCore.
self
.
input_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
input_socket
.
bind
(
input_path
)
self
.
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
self
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_path
,
zmq
.
constants
.
PUSH
)
# Start EngineCore in background process.
self
.
proc_handle
:
Optional
[
BackgroundProcHandle
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment