Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cde9183b
Unverified
Commit
cde9183b
authored
Aug 21, 2024
by
Joe Runde
Committed by
GitHub
Aug 22, 2024
Browse files
[Bug][Frontend] Improve ZMQ client robustness (#7443)
Signed-off-by:
Joe Runde
<
Joseph.Runde@ibm.com
>
parent
df1a2113
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
176 additions
and
28 deletions
+176
-28
tests/entrypoints/openai/rpc/__init__.py
tests/entrypoints/openai/rpc/__init__.py
+0
-0
tests/entrypoints/openai/rpc/test_zmq_client.py
tests/entrypoints/openai/rpc/test_zmq_client.py
+119
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+3
-2
vllm/entrypoints/openai/rpc/__init__.py
vllm/entrypoints/openai/rpc/__init__.py
+0
-4
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+48
-22
vllm/envs.py
vllm/envs.py
+6
-0
No files found.
tests/entrypoints/openai/rpc/__init__.py
0 → 100644
View file @
cde9183b
tests/entrypoints/openai/rpc/test_zmq_client.py
0 → 100644
View file @
cde9183b
import
asyncio
import
tempfile
import
unittest
import
unittest.mock
import
uuid
import
pytest
import
pytest_asyncio
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.rpc.client
import
(
AsyncEngineRPCClient
,
RPCClientClosedError
)
from
vllm.entrypoints.openai.rpc.server
import
AsyncEngineRPCServer
@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a fresh ``ipc://`` address for one test.

    The address points at a uuid-named path inside a temporary directory.
    The directory is removed when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        socket_name = uuid.uuid4()
        yield f"ipc://{scratch_dir}/{socket_name}"
@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
    """Run an ``AsyncEngineRPCServer`` backed by a mock engine.

    ``AsyncLLMEngine.from_engine_args`` is patched while the server is
    constructed so that it hands back an ``AsyncMock`` instead of building
    an engine. The server loop runs as a task for the duration of the test;
    on teardown the task is cancelled and the server is cleaned up.
    """
    fake_engine = unittest.mock.AsyncMock()

    with monkeypatch.context() as patcher:
        # Whatever arguments the server passes, return the mock engine.
        patcher.setattr(AsyncLLMEngine, "from_engine_args",
                        lambda *args, **kwargs: fake_engine)
        server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)

    server_task = asyncio.get_running_loop().create_task(
        server.run_server_loop())

    try:
        yield server
    finally:
        server_task.cancel()
        server.cleanup()
@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
    """Yield an ``AsyncEngineRPCClient`` connected to ``tmp_socket``.

    The client is closed on teardown. The readiness check runs inside the
    ``try`` block so that the client is also closed when the check itself
    raises (for example when the server never answers); otherwise a failed
    check would leave the client open.
    """
    rpc_client = AsyncEngineRPCClient(rpc_path=tmp_socket)

    try:
        # Sanity check: the server is connected
        await rpc_client._wait_for_server_rpc()
        yield rpc_client
    finally:
        rpc_client.close()
@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
                                                client: AsyncEngineRPCClient):
    """``client.setup()`` must raise the client's own TimeoutError.

    The ``match`` on the message checks that the error is the one raised by
    the client rather than a bare timeout from ``asyncio.wait_for``.
    """

    def swallow_request(x):
        # Stand-in for the server handler: never produces a model config.
        return None

    with monkeypatch.context() as patcher:
        # Make the server _not_ reply with a model config
        patcher.setattr(dummy_server, "get_config", swallow_request)
        patcher.setattr(client, "_data_timeout", 10)

        # And ensure the task completes anyway
        # (client.setup() invokes server.get_config())
        loop = asyncio.get_running_loop()
        pending_setup = loop.create_task(client.setup())

        expected = pytest.raises(TimeoutError,
                                 match="Server didn't reply within")
        with expected:
            await asyncio.wait_for(pending_setup, timeout=0.05)
@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
                                          client: AsyncEngineRPCClient):
    """``client.abort()`` must time out rather than hang on a silent server."""

    def swallow_abort(x):
        # Stand-in for the server's abort handler: never replies.
        return None

    with monkeypatch.context() as patcher:
        # Hang all abort requests
        patcher.setattr(dummy_server, "abort", swallow_abort)
        patcher.setattr(client, "_data_timeout", 10)

        # Ensure the client doesn't hang
        loop = asyncio.get_running_loop()
        pending_abort = loop.create_task(client.abort("test request id"))

        expected = pytest.raises(TimeoutError,
                                 match="Server didn't reply within")
        with expected:
            await asyncio.wait_for(pending_abort, timeout=0.05)
@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
        monkeypatch, dummy_server, client: AsyncEngineRPCClient):
    """An exception raised by the engine must resurface from ``setup()``."""
    # Make the server raise some random exception
    injected_error = RuntimeError("Client test exception")

    def explode():
        raise injected_error

    with monkeypatch.context() as patcher:
        patcher.setattr(dummy_server.engine, "get_model_config", explode)
        patcher.setattr(client, "_data_timeout", 10)

        loop = asyncio.get_running_loop()
        pending_setup = loop.create_task(client.setup())

        # And ensure the task completes, raising the exception
        expected = pytest.raises(RuntimeError, match=str(injected_error))
        with expected:
            await asyncio.wait_for(pending_setup, timeout=0.05)
@pytest.mark.asyncio
async def test_client_errors_after_closing(monkeypatch, dummy_server,
                                           client: AsyncEngineRPCClient):
    """Check how each client method behaves once the client is closed."""
    client.close()

    # Healthchecks and generate requests will fail with explicit errors
    with pytest.raises(RPCClientClosedError):
        await client.check_health()

    with pytest.raises(RPCClientClosedError):
        outputs = client.generate(None, None, None)
        async for _ in outputs:
            pass

    # But no-ops like aborting will pass
    await client.abort("test-request-id")
    await client.do_log_stats()
vllm/entrypoints/openai/api_server.py
View file @
cde9183b
...
...
@@ -6,7 +6,7 @@ import os
import
re
import
tempfile
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
,
suppress
from
http
import
HTTPStatus
from
typing
import
AsyncIterator
,
Optional
,
Set
...
...
@@ -83,6 +83,7 @@ async def lifespan(app: FastAPI):
async
def
_force_log
():
while
True
:
await
asyncio
.
sleep
(
10
)
with
suppress
(
Exception
):
await
async_engine_client
.
do_log_stats
()
if
not
engine_args
.
disable_log_stats
:
...
...
vllm/entrypoints/openai/rpc/__init__.py
View file @
cde9183b
...
...
@@ -10,10 +10,6 @@ from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
# Timeouts.
VLLM_RPC_SERVER_START_TIMEOUT_MS
=
1000
VLLM_RPC_HEALTH_TIMEOUT_MS
=
10000
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF
=
2000
...
...
vllm/entrypoints/openai/rpc/client.py
View file @
cde9183b
import
asyncio
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
,
suppress
from
typing
import
Any
,
AsyncGenerator
,
Mapping
,
Optional
from
uuid
import
uuid4
...
...
@@ -11,13 +11,12 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig
,
SchedulerConfig
)
# yapf: disable
from
vllm.entrypoints.openai.rpc
import
(
RPC_REQUEST_TYPE
,
VLLM_RPC_HEALTH_TIMEOUT_MS
,
VLLM_RPC_SERVER_START_TIMEOUT_MS
,
VLLM_RPC_SOCKET_LIMIT_CUTOFF
,
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_GET_DATA_TIMEOUT_MS
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -32,6 +31,17 @@ logger = init_logger(__name__)
INPROC_PROXY_PATH
=
f
"inproc://
{
uuid4
()
}
"
class RPCClientClosedError(Exception):
    """Raised when the RPC client is used after it has been closed.

    Closing the client closes its ZMQ context, which normally happens on
    server shutdown. Methods such as abort and do_log_stats may still be
    called afterwards and would then try to open a socket, producing a
    ZMQError and a very large stack trace.

    This dedicated error exists so that callers can suppress that case.
    """
class
AsyncEngineRPCClient
:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
...
...
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
def
__init__
(
self
,
rpc_path
:
str
):
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
_data_timeout
=
VLLM_RPC_GET_DATA_TIMEOUT_MS
self
.
_errored
=
False
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
...
...
@@ -143,7 +155,6 @@ class AsyncEngineRPCClient:
# Wait until server is ready.
await
self
.
_wait_for_server_rpc
()
self
.
_errored
=
False
# Get the configs.
self
.
model_config
=
await
self
.
_get_model_config_rpc
()
...
...
@@ -170,6 +181,15 @@ class AsyncEngineRPCClient:
@
contextmanager
def
to_proxy_socket
(
self
):
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if
self
.
context
.
closed
:
raise
RPCClientClosedError
(
"The ZMQ client has already shut down"
)
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
...
...
@@ -189,9 +209,18 @@ class AsyncEngineRPCClient:
# Ping RPCServer with a request.
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
request
)])
# Make sure the server responds
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
# Await the data from the Server.
data
=
cloudpickle
.
loads
(
await
socket
.
recv
())
if
isinstance
(
data
,
Exception
):
# Re-raise exceptions returned by the server
raise
data
if
not
isinstance
(
data
,
expected_type
):
# LoRAConfig can be None.
if
expected_type
==
LoRAConfig
and
data
is
None
:
...
...
@@ -208,29 +237,28 @@ class AsyncEngineRPCClient:
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
,
timeout
:
Optional
[
int
]
=
None
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
):
"""Send one-way RPC request to trigger an action."""
async
def
do_rpc_call
(
socket
:
zmq
.
asyncio
.
Socket
,
request
:
RPC_REQUEST_TYPE
,
timeout
=
None
):
request
:
RPC_REQUEST_TYPE
):
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
request
)])
if
timeout
is
not
None
and
await
socket
.
poll
(
timeout
=
timeout
)
==
0
:
raise
TimeoutError
(
f
"Server didn't reply within
{
timeout
}
ms"
)
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
return
cloudpickle
.
loads
(
await
socket
.
recv
())
# Make a new socket connection.
if
socket
is
None
:
with
self
.
to_proxy_socket
()
as
socket
:
response
=
await
do_rpc_call
(
socket
,
request
,
timeout
)
response
=
await
do_rpc_call
(
socket
,
request
)
# Use existing socket connection.
else
:
response
=
await
do_rpc_call
(
socket
,
request
,
timeout
)
response
=
await
do_rpc_call
(
socket
,
request
)
if
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
:
if
isinstance
(
response
,
Exception
):
...
...
@@ -255,8 +283,7 @@ class AsyncEngineRPCClient:
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_READY
,
error_message
=
"Unable to start RPC Server"
,
timeout
=
VLLM_RPC_SERVER_START_TIMEOUT_MS
)
error_message
=
"Unable to start RPC Server"
)
async
def
_get_model_config_rpc
(
self
)
->
ModelConfig
:
"""Get the ModelConfig object from the RPC Server"""
...
...
@@ -308,14 +335,14 @@ class AsyncEngineRPCClient:
async def abort(self, request_id: str):
    """Send an ABORT_REQUEST signal for ``request_id`` to the RPC Server.

    An ``RPCClientClosedError`` raised along the way is swallowed, so the
    call is a silent no-op once the client has been closed.
    """
    try:
        abort_request = RPCAbortRequest(request_id)
        failure_text = f"RPCAbortRequest {request_id} failed"
        await self._send_one_way_rpc_request(request=abort_request,
                                             error_message=failure_text)
    except RPCClientClosedError:
        # The client was already shut down; there is nothing to abort.
        pass
async def do_log_stats(self):
    """Send a DO_LOG_STATS signal to the RPC Server.

    An ``RPCClientClosedError`` raised along the way is swallowed, so the
    call is a silent no-op once the client has been closed.
    """
    try:
        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.",
        )
    except RPCClientClosedError:
        # The client was already shut down; skip logging.
        pass
...
...
@@ -393,7 +420,6 @@ class AsyncEngineRPCClient:
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_HEALTHY
,
error_message
=
"Got Unhealthy response from RPC Server"
,
timeout
=
VLLM_RPC_HEALTH_TIMEOUT_MS
,
socket
=
socket
)
async
def
encode
(
self
,
*
args
,
...
...
vllm/envs.py
View file @
cde9183b
...
...
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
VERBOSE
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_RPC_GET_DATA_TIMEOUT_MS
:
int
=
5000
VLLM_ALLOW_ENGINE_USE_RAY
:
bool
=
False
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
...
...
@@ -374,6 +375,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(
os
.
environ
.
get
(
"VLLM_TEST_FORCE_FP8_MARLIN"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"VLLM_RPC_GET_DATA_TIMEOUT_MS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_RPC_GET_DATA_TIMEOUT_MS"
,
"5000"
)),
# If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed.
# See https://github.com/vllm-project/vllm/issues/7045
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment