Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4289cad3
Unverified
Commit
4289cad3
authored
Aug 28, 2024
by
Nick Hill
Committed by
GitHub
Aug 28, 2024
Browse files
[Frontend] Minor optimizations to zmq decoupled front-end (#7957)
Co-authored-by:
Robert Shaw
<
rshaw@neuralmagic
>
parent
af59df0a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
65 deletions
+64
-65
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+37
-44
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+27
-21
No files found.
vllm/entrypoints/openai/rpc/client.py
View file @
4289cad3
import
asyncio
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
Any
,
AsyncGenerator
,
Mapping
,
Optional
from
typing
import
Any
,
AsyncGenerator
,
Iterator
,
Mapping
,
Optional
from
uuid
import
uuid4
import
cloudpickle
import
zmq
import
zmq.asyncio
from
zmq.asyncio
import
Socket
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
...
...
@@ -115,18 +117,21 @@ class AsyncEngineRPCClient:
self
.
context
.
set
(
zmq
.
constants
.
MAX_SOCKETS
,
socket_limit
)
# IPC connection to RPC Server (uses unix sockets).
self
.
to_rpc_server
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
to_rpc_server
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
to_rpc_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
to_rpc_server
.
bind
(
rpc_path
)
# In process proxy to RPC Server (uses memory-based messaging).
self
.
from_api_server
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
from_api_server
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
from_api_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
from_api_server
.
bind
(
INPROC_PROXY_PATH
)
# Asyncio background task for the proxy.
self
.
proxy_task
=
asyncio
.
create_task
(
self
.
proxy_
in_
task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
from_api_server
,
self
.
to_rpc_server
))
self
.
proxy_out_task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
to_rpc_server
,
self
.
from_api_server
))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
...
...
@@ -136,20 +141,11 @@ class AsyncEngineRPCClient:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self
.
limit_concurrency
=
socket_limit
//
2
-
2
async
def
run_proxy
(
self
,
socket_from
,
socket_to
):
async
def
run_proxy
(
self
,
socket_from
:
Socket
,
socket_to
:
Socket
):
"""Background task that runs a proxy"""
poller
=
zmq
.
asyncio
.
Poller
()
poller
.
register
(
socket_from
,
zmq
.
constants
.
POLLIN
)
poller
.
register
(
socket_to
,
zmq
.
constants
.
POLLIN
)
while
True
:
events_lst
=
await
poller
.
poll
()
events
=
dict
(
events_lst
)
if
socket_from
in
events
:
identity
,
msg
=
await
socket_from
.
recv_multipart
()
await
socket_to
.
send_multipart
([
identity
,
msg
])
if
socket_to
in
events
:
identity
,
msg
=
await
socket_to
.
recv_multipart
()
await
socket_from
.
send_multipart
([
identity
,
msg
])
frames
=
await
socket_from
.
recv_multipart
(
copy
=
False
)
await
socket_to
.
send_multipart
(
frames
,
copy
=
False
)
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
...
...
@@ -180,7 +176,7 @@ class AsyncEngineRPCClient:
self
.
context
.
destroy
()
@
contextmanager
def
to_proxy_socket
(
self
):
def
to_proxy_socket
(
self
)
->
Iterator
[
Socket
]
:
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
...
...
@@ -208,7 +204,8 @@ class AsyncEngineRPCClient:
with
self
.
to_proxy_socket
()
as
socket
:
# Ping RPCServer with a request.
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
request
)])
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
request
),
),
copy
=
False
)
# Make sure the server responds
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
...
...
@@ -216,7 +213,8 @@ class AsyncEngineRPCClient:
f
"
{
self
.
_data_timeout
}
ms"
)
# Await the data from the Server.
data
=
cloudpickle
.
loads
(
await
socket
.
recv
())
frame
=
await
socket
.
recv
(
copy
=
False
)
data
=
pickle
.
loads
(
frame
.
buffer
)
if
isinstance
(
data
,
Exception
):
# Re-raise exceptions returned by the server
...
...
@@ -234,23 +232,22 @@ class AsyncEngineRPCClient:
return
data
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
):
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
,
socket
:
Optional
[
Socket
]
=
None
):
"""Send one-way RPC request to trigger an action."""
async
def
do_rpc_call
(
socket
:
zmq
.
asyncio
.
Socket
,
request
:
RPC_REQUEST_TYPE
):
async
def
do_rpc_call
(
socket
:
Socket
,
request
:
RPC_REQUEST_TYPE
):
await
socket
.
send_multipart
(
[
cloudpickle
.
dumps
(
request
)
]
)
await
socket
.
send_multipart
(
(
cloudpickle
.
dumps
(
request
)
,
)
)
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
return
cloudpickle
.
loads
(
await
socket
.
recv
())
frame
=
await
socket
.
recv
(
copy
=
False
)
return
pickle
.
loads
(
frame
.
buffer
)
# Make a new socket connection.
if
socket
is
None
:
...
...
@@ -386,21 +383,19 @@ class AsyncEngineRPCClient:
try
:
with
self
.
to_proxy_socket
()
as
socket
:
# Send RPCGenerateRequest to the RPCServer.
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
RPCGenerateRequest
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
))
])
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
RPCGenerateRequest
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)),
))
# Stream back the results from the RPC Server.
while
not
finished
:
message
=
await
socket
.
recv
()
request_output
=
cloud
pickle
.
loads
(
message
)
message
=
await
socket
.
recv
(
copy
=
False
)
request_output
=
pickle
.
loads
(
message
.
buffer
)
if
isinstance
(
request_output
,
Exception
):
# On exception, check if the server is still healthy
...
...
@@ -424,9 +419,7 @@ class AsyncEngineRPCClient:
if
not
finished
and
not
self
.
_errored
:
await
self
.
abort
(
request_id
)
async
def
check_health
(
self
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
)
->
None
:
async
def
check_health
(
self
,
socket
:
Optional
[
Socket
]
=
None
)
->
None
:
"""Raise if unhealthy"""
await
self
.
_send_one_way_rpc_request
(
...
...
@@ -451,4 +444,4 @@ class AsyncEngineRPCClient:
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
STOP_PROFILE
,
error_message
=
"RPCRequest STOP_PROFILE failed."
)
\ No newline at end of file
error_message
=
"RPCRequest STOP_PROFILE failed."
)
vllm/entrypoints/openai/rpc/server.py
View file @
4289cad3
import
asyncio
import
pickle
import
signal
from
typing
import
Any
,
Coroutine
,
Union
...
...
@@ -7,6 +8,8 @@ import uvloop
import
zmq
import
zmq.asyncio
from
typing_extensions
import
Never
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
...
...
@@ -35,7 +38,7 @@ class AsyncEngineRPCServer:
self
.
context
=
zmq
.
asyncio
.
Context
()
# Init socket.
self
.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
socket
.
connect
(
rpc_path
)
...
...
@@ -63,30 +66,31 @@ class AsyncEngineRPCServer:
else
:
raise
ValueError
(
"Unknown Config Request: %s"
,
request
)
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
config
)]
)
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
config
)),
copy
=
False
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
is_tracing_enabled
(
self
,
identity
):
"""Send the is_tracing_enabled flag"""
tracing_flag
=
await
self
.
engine
.
is_tracing_enabled
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
tracing_flag
)
]
)
(
identity
,
pickle
.
dumps
(
tracing_flag
)
)
)
async
def
do_log_stats
(
self
,
identity
):
"""Log stats and confirm success."""
await
self
.
engine
.
do_log_stats
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
]
)
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
)
)
async
def
is_server_ready
(
self
,
identity
):
"""Notify the client that we are ready."""
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
]
)
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
)
)
async
def
abort
(
self
,
identity
,
request
:
RPCAbortRequest
):
"""Abort request and notify the client of success."""
...
...
@@ -96,7 +100,7 @@ class AsyncEngineRPCServer:
result
:
Union
[
str
,
Exception
]
=
VLLM_RPC_SUCCESS_STR
except
Exception
as
e
:
result
=
e
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
result
)
]
)
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
result
)
)
)
async
def
generate
(
self
,
identity
,
generate_request
:
RPCGenerateRequest
):
try
:
...
...
@@ -110,45 +114,47 @@ class AsyncEngineRPCServer:
async
for
request_output
in
results_generator
:
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
request_output
)
]
)
(
identity
,
pickle
.
dumps
(
request_output
)
),
copy
=
False
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
check_health
(
self
,
identity
):
try
:
await
self
.
engine
.
check_health
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
]
)
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
)
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
start_profile
(
self
,
identity
):
logger
.
info
(
"Starting profiler..."
)
await
self
.
engine
.
start_profile
()
logger
.
info
(
"Profiler started."
)
await
self
.
socket
.
send_multipart
(
[
await
self
.
socket
.
send_multipart
(
(
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
]
)
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
)
)
async
def
stop_profile
(
self
,
identity
):
logger
.
info
(
"Stopping profiler..."
)
await
self
.
engine
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
await
self
.
socket
.
send_multipart
(
[
await
self
.
socket
.
send_multipart
(
(
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
]
)
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
)
)
def
_make_handler_coro
(
self
,
identity
,
message
)
->
Coroutine
[
Any
,
Any
,
Never
]:
message
:
Frame
)
->
Coroutine
[
Any
,
Any
,
Never
]:
"""Route the zmq message to the handler coroutine."""
request
=
cloudpickle
.
loads
(
message
)
request
=
cloudpickle
.
loads
(
message
.
buffer
)
if
isinstance
(
request
,
RPCGenerateRequest
):
return
self
.
generate
(
identity
,
request
)
...
...
@@ -189,7 +195,7 @@ class AsyncEngineRPCServer:
running_tasks
=
set
()
while
True
:
# Wait for a request.
identity
,
message
=
await
self
.
socket
.
recv_multipart
()
identity
,
message
=
await
self
.
socket
.
recv_multipart
(
copy
=
False
)
# Process the request async.
task
=
asyncio
.
create_task
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment