Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f7e3b0c5
Unverified
Commit
f7e3b0c5
authored
Aug 21, 2024
by
Robert Shaw
Committed by
GitHub
Aug 21, 2024
Browse files
[Bugfix][Frontend] Fix Issues Under High Load With `zeromq` Frontend (#7394)
Co-authored-by:
Nick Hill
<
nickhill@us.ibm.com
>
parent
d3c002ea
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
322 additions
and
141 deletions
+322
-141
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/entrypoints/openai/test_accuracy.py
tests/entrypoints/openai/test_accuracy.py
+55
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+5
-0
vllm/engine/protocol.py
vllm/engine/protocol.py
+4
-0
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+9
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+6
-5
vllm/entrypoints/openai/rpc/__init__.py
vllm/entrypoints/openai/rpc/__init__.py
+12
-2
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+180
-68
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+50
-66
No files found.
.buildkite/test-pipeline.yaml
View file @
f7e3b0c5
...
@@ -86,6 +86,7 @@ steps:
...
@@ -86,6 +86,7 @@ steps:
-
vllm/
-
vllm/
commands
:
commands
:
-
pip install -e ./plugins/vllm_add_dummy_model
-
pip install -e ./plugins/vllm_add_dummy_model
-
pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
-
pytest -v -s entrypoints/llm
-
pytest -v -s entrypoints/llm
-
pytest -v -s entrypoints/openai
-
pytest -v -s entrypoints/openai
...
...
tests/entrypoints/openai/test_accuracy.py
0 → 100644
View file @
f7e3b0c5
"""
This file tests the accuracy of the vLLM server via LMEval.
It uses local-completions, which interacts with vLLM
through the OAI API with N concurrent connections.
This simulates real-world usage of the API and makes
sure that the zmq frontend mp RPC message passing and
AsyncLLMEngine are working correctly.
"""
import
lm_eval
import
pytest
from
...utils
import
RemoteOpenAIServer
# Model that the test server loads and that lm_eval is pointed at.
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
# Forwarded to lm_eval as `num_concurrent` in the model args.
NUM_CONCURRENT = 500
# lm_eval task to run; also the key used to look up its results.
TASK = "gsm8k"
# Metric key read from the task's results.
FILTER = "exact_match,strict-match"
# Allowed deviation of the measured score from EXPECTED_VALUE. The test
# applies it as a plain +/- margin around the measured value.
RTOL = 0.03
# Score the measured metric is compared against.
EXPECTED_VALUE = 0.58
@pytest.fixture(scope="module")
def server():
    """Yield a RemoteOpenAIServer for MODEL_NAME, shared across the module.

    The server object is used as a context manager, so it is torn down
    when the module's tests are finished with the fixture.
    """
    cli_flags = [
        "--max-model-len",
        "4096",
        "--enable-chunked-prefill",
        "--disable-log-requests",
        "--enforce-eager",
    ]

    with RemoteOpenAIServer(MODEL_NAME, cli_flags) as running_server:
        yield running_server
@pytest.fixture(scope="module")
def server_data(server):
    """Return a dict whose "url" entry is the server's completions endpoint.

    The endpoint is the server's 'v1' URL with "/completions" appended.
    """
    v1_base = server.url_for('v1')
    return dict(url=f"{v1_base}/completions")
def test_lm_eval_accuracy(server_data):
    """Run TASK through lm_eval against the server and check the score.

    lm_eval's "local-completions" model is pointed at the completions URL
    from `server_data` with NUM_CONCURRENT concurrent connections. The
    FILTER metric of the result must lie strictly within RTOL of
    EXPECTED_VALUE.
    """
    # lm_eval takes its model arguments as one comma-separated string.
    model_args = ",".join([
        f"model={MODEL_NAME}",
        f"base_url={server_data['url']}",
        f"num_concurrent={NUM_CONCURRENT}",
        "tokenized_requests=False",
    ])

    evaluation = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=TASK,
    )

    measured_value = evaluation["results"][TASK][FILTER]
    not_too_high = measured_value - RTOL < EXPECTED_VALUE
    not_too_low = measured_value + RTOL > EXPECTED_VALUE
    assert (not_too_high and not_too_low
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
vllm/engine/async_llm_engine.py
View file @
f7e3b0c5
...
@@ -766,6 +766,11 @@ class AsyncLLMEngine:
...
@@ -766,6 +766,11 @@ class AsyncLLMEngine:
def
errored
(
self
)
->
bool
:
def
errored
(
self
)
->
bool
:
return
self
.
_errored_with
is
not
None
return
self
.
_errored_with
is
not
None
@property
def limit_concurrency(self) -> Optional[int]:
    """Upper bound on concurrently running requests, if any.

    Always None for this engine.
    """
    return None
def
set_errored
(
self
,
exc
:
Exception
)
->
None
:
def
set_errored
(
self
,
exc
:
Exception
)
->
None
:
self
.
_errored_with
=
exc
self
.
_errored_with
=
exc
...
...
vllm/engine/protocol.py
View file @
f7e3b0c5
...
@@ -29,6 +29,10 @@ class AsyncEngineClient(Protocol):
...
@@ -29,6 +29,10 @@ class AsyncEngineClient(Protocol):
def
errored
(
self
)
->
bool
:
def
errored
(
self
)
->
bool
:
...
...
@property
def limit_concurrency(self) -> Optional[int]:
    """Upper bound on concurrently running requests, or None."""
def
generate
(
def
generate
(
self
,
self
,
inputs
:
PromptInputs
,
inputs
:
PromptInputs
,
...
...
vllm/entrypoints/launcher.py
View file @
f7e3b0c5
...
@@ -27,6 +27,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
...
@@ -27,6 +27,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has a maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if
engine
.
limit_concurrency
is
not
None
:
logger
.
info
(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing"
,
engine
.
limit_concurrency
)
uvicorn_kwargs
[
"limit_concurrency"
]
=
engine
.
limit_concurrency
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
server
=
uvicorn
.
Server
(
config
)
server
=
uvicorn
.
Server
(
config
)
_add_shutdown_handlers
(
app
,
server
,
engine
)
_add_shutdown_handlers
(
app
,
server
,
engine
)
...
...
vllm/entrypoints/openai/api_server.py
View file @
f7e3b0c5
...
@@ -135,6 +135,12 @@ async def build_async_engine_client(
...
@@ -135,6 +135,12 @@ async def build_async_engine_client(
logger
.
info
(
"Multiprocessing frontend to use %s for RPC Path."
,
logger
.
info
(
"Multiprocessing frontend to use %s for RPC Path."
,
rpc_path
)
rpc_path
)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client
=
AsyncEngineRPCClient
(
rpc_path
)
async_engine_client
=
rpc_client
# type: ignore
# Start RPCServer in separate process (holds the AsyncLLMEngine).
# Start RPCServer in separate process (holds the AsyncLLMEngine).
context
=
multiprocessing
.
get_context
(
"spawn"
)
context
=
multiprocessing
.
get_context
(
"spawn"
)
# the current process might have CUDA context,
# the current process might have CUDA context,
...
@@ -145,11 +151,6 @@ async def build_async_engine_client(
...
@@ -145,11 +151,6 @@ async def build_async_engine_client(
rpc_server_process
.
start
()
rpc_server_process
.
start
()
logger
.
info
(
"Started engine process with PID %d"
,
logger
.
info
(
"Started engine process with PID %d"
,
rpc_server_process
.
pid
)
rpc_server_process
.
pid
)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client
=
AsyncEngineRPCClient
(
rpc_path
)
async_engine_client
=
rpc_client
# type: ignore
try
:
try
:
while
True
:
while
True
:
...
...
vllm/entrypoints/openai/rpc/__init__.py
View file @
f7e3b0c5
...
@@ -7,8 +7,18 @@ from vllm.lora.request import LoRARequest
...
@@ -7,8 +7,18 @@ from vllm.lora.request import LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
VLLM_RPC_HEALTHY_STR
=
"HEALTHY"
# Timeouts.
VLLM_RPC_SERVER_START_TIMEOUT_MS
=
1000
VLLM_RPC_HEALTH_TIMEOUT_MS
=
10000
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF
=
2000
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM
=
0
@
dataclass
@
dataclass
...
@@ -34,7 +44,7 @@ class RPCUtilityRequest(Enum):
...
@@ -34,7 +44,7 @@ class RPCUtilityRequest(Enum):
GET_SCHEDULER_CONFIG
=
5
GET_SCHEDULER_CONFIG
=
5
GET_LORA_CONFIG
=
6
GET_LORA_CONFIG
=
6
DO_LOG_STATS
=
7
DO_LOG_STATS
=
7
CHECK
_HEALTH
=
8
IS_SERVER
_HEALTH
Y
=
8
IS_TRACING_ENABLED
=
9
IS_TRACING_ENABLED
=
9
...
...
vllm/entrypoints/openai/rpc/client.py
View file @
f7e3b0c5
import
asyncio
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
AsyncGenerator
,
Mapping
,
Optional
from
typing
import
Any
,
AsyncGenerator
,
Mapping
,
Optional
from
uuid
import
uuid4
import
cloudpickle
import
cloudpickle
import
zmq
import
zmq
...
@@ -7,32 +9,140 @@ import zmq.asyncio
...
@@ -7,32 +9,140 @@ import zmq.asyncio
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
# yapf: disable
from
vllm.entrypoints.openai.rpc
import
(
RPC_REQUEST_TYPE
,
from
vllm.entrypoints.openai.rpc
import
(
RPC_REQUEST_TYPE
,
VLLM_RPC_HEALTHY_STR
,
VLLM_RPC_HEALTH_TIMEOUT_MS
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
VLLM_RPC_SERVER_START_TIMEOUT_MS
,
VLLM_RPC_SOCKET_LIMIT_CUTOFF
,
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
RPCGenerateRequest
,
RPCUtilityRequest
)
# yapf: enable
from
vllm.inputs
import
PromptInputs
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
# Time to wait before checking it the server process is alive.
logger
=
init_logger
(
__name__
)
SERVER_START_TIMEOUT_MS
=
1000
# Path used for inprocess proxy.
INPROC_PROXY_PATH
=
f
"inproc://
{
uuid4
()
}
"
class
AsyncEngineRPCClient
:
class
AsyncEngineRPCClient
:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER, ipc)
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RPCGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. A ROUTER socket tracks every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""
def
__init__
(
self
,
rpc_path
:
str
):
def
__init__
(
self
,
rpc_path
:
str
):
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
rpc_path
=
rpc_path
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit
=
self
.
context
.
get
(
zmq
.
constants
.
SOCKET_LIMIT
)
if
socket_limit
<
VLLM_RPC_SOCKET_LIMIT_CUTOFF
:
raise
ValueError
(
f
"Found zmq.constants.SOCKET_LIMIT=
{
socket_limit
}
, which caps "
"the number of concurrent requests vLLM can process. Launch "
"vLLM with --disable-frontend-multiprocessing and open a "
"GitHub issue so we can investigate."
)
# We only have 1 ipc connection that uses unix sockets, so
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
# not run into ulimit issues)
self
.
context
.
set
(
zmq
.
constants
.
MAX_SOCKETS
,
socket_limit
)
# IPC connection to RPC Server (uses unix sockets).
self
.
to_rpc_server
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
to_rpc_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
to_rpc_server
.
bind
(
rpc_path
)
# In process proxy to RPC Server (uses memory-based messaging).
self
.
from_api_server
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
from_api_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
from_api_server
.
bind
(
INPROC_PROXY_PATH
)
# Asyncio background task for the proxy.
self
.
proxy_task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
from_api_server
,
self
.
to_rpc_server
))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
# multiprocessing. This value is used by uvicorn to launch
# with --limit-concurrency to return 503 when server is overloaded.
# We need 2 sockets per request - 2:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self
.
limit_concurrency
=
socket_limit
//
2
-
2
async def run_proxy(self, socket_from, socket_to):
    """Relay two-frame messages between two sockets in an endless loop.

    Both sockets are registered for readability with a zmq.asyncio
    poller. Each time the poller wakes up, every readable socket has one
    ``[identity, msg]`` message read from it and re-sent unchanged on
    the other socket. The loop has no exit condition of its own.

    Args:
        socket_from: First endpoint; serviced first on every wake-up.
        socket_to: Second endpoint; serviced after ``socket_from``.
    """
    readiness = zmq.asyncio.Poller()
    for endpoint in (socket_from, socket_to):
        readiness.register(endpoint, zmq.constants.POLLIN)

    # Each direction is served at most once per wake-up, always in the
    # same order: socket_from -> socket_to, then socket_to -> socket_from.
    directions = ((socket_from, socket_to), (socket_to, socket_from))

    while True:
        ready = dict(await readiness.poll())
        for source, sink in directions:
            if source in ready:
                # Exactly two frames are expected; any other frame count
                # makes this unpacking raise.
                identity, msg = await source.recv_multipart()
                await sink.send_multipart([identity, msg])
async
def
setup
(
self
):
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
"""Setup the client before it starts sending server requests."""
# Wait until server is ready.
# Wait until server is ready.
await
self
.
wait_for_server
()
await
self
.
_
wait_for_server
_rpc
()
self
.
_errored
=
False
self
.
_errored
=
False
# Get the configs.
# Get the configs.
...
@@ -51,29 +161,23 @@ class AsyncEngineRPCClient:
...
@@ -51,29 +161,23 @@ class AsyncEngineRPCClient:
def
close
(
self
):
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
"""Destroy the ZeroMQ Context."""
# Close all sockets associated with this context and
# then terminate the context.
self
.
from_api_server
.
close
()
self
.
to_rpc_server
.
close
()
self
.
context
.
destroy
()
self
.
context
.
destroy
()
@
contextmanager
@
contextmanager
def
socket
(
self
):
def
to_proxy_socket
(
self
):
# Ensure client sockets are always closed after use
# Connect to the RPCServer via the proxy.
# Connect to RPC socket for Request-Reply pattern,
# Note that we use DEALER to enable asynchronous communication
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
# to enable streaming.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
try
:
try
:
socket
.
connect
(
self
.
rpc_path
)
socket
.
connect
(
INPROC_PROXY_PATH
)
yield
socket
yield
socket
finally
:
finally
:
# linger == 0 means discard unsent messages
# when the socket is closed. This is necessary
# because otherwise self.context.destroy() will
# wait for 30 seconds until unsent messages are
# received, which is impossible if the server
# crashed. In the absence of a server crash we
# always expect a response before closing the
# socket anyway.
# Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
socket
.
close
(
linger
=
0
)
socket
.
close
(
linger
=
0
)
async
def
_send_get_data_rpc_request
(
self
,
request
:
RPCUtilityRequest
,
async
def
_send_get_data_rpc_request
(
self
,
request
:
RPCUtilityRequest
,
...
@@ -81,10 +185,9 @@ class AsyncEngineRPCClient:
...
@@ -81,10 +185,9 @@ class AsyncEngineRPCClient:
error_message
:
str
)
->
Any
:
error_message
:
str
)
->
Any
:
"""Send an RPC request that is expecting data back."""
"""Send an RPC request that is expecting data back."""
with
self
.
socket
()
as
socket
:
with
self
.
to_proxy_socket
()
as
socket
:
# Ping RPCServer with a request.
# Ping RPCServer with a request.
await
socket
.
send
(
cloudpickle
.
dumps
(
request
))
await
socket
.
send
_multipart
([
cloudpickle
.
dumps
(
request
)
]
)
# Await the data from the Server.
# Await the data from the Server.
data
=
cloudpickle
.
loads
(
await
socket
.
recv
())
data
=
cloudpickle
.
loads
(
await
socket
.
recv
())
...
@@ -93,31 +196,48 @@ class AsyncEngineRPCClient:
...
@@ -93,31 +196,48 @@ class AsyncEngineRPCClient:
# LoRAConfig can be None.
# LoRAConfig can be None.
if
expected_type
==
LoRAConfig
and
data
is
None
:
if
expected_type
==
LoRAConfig
and
data
is
None
:
pass
pass
elif
isinstance
(
data
,
Exception
):
logger
.
error
(
error_message
)
raise
data
else
:
else
:
raise
ValueError
(
error_message
)
raise
ValueError
(
error_message
)
return
data
return
data
async
def
_send_one_way_rpc_request
(
self
,
async
def
_send_one_way_rpc_request
(
request
:
RPC_REQUEST_TYPE
,
self
,
error_message
:
str
,
request
:
RPC_REQUEST_TYPE
,
timeout
:
Optional
[
int
]
=
None
):
error_message
:
str
,
timeout
:
Optional
[
int
]
=
None
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
):
"""Send one-way RPC request to trigger an action."""
"""Send one-way RPC request to trigger an action."""
with
self
.
socket
()
as
socket
:
# Ping RPC Server with request.
await
socket
.
send
(
cloudpickle
.
dumps
(
request
))
# Await acknowledgement from RPCServer.
async
def
do_rpc_call
(
socket
:
zmq
.
asyncio
.
Socket
,
request
:
RPC_REQUEST_TYPE
,
timeout
=
None
):
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
request
)])
if
timeout
is
not
None
and
await
socket
.
poll
(
timeout
=
timeout
)
==
0
:
if
timeout
is
not
None
and
await
socket
.
poll
(
timeout
=
timeout
)
==
0
:
raise
TimeoutError
(
f
"server didn't reply within
{
timeout
}
ms"
)
raise
TimeoutError
(
f
"Server didn't reply within
{
timeout
}
ms"
)
return
cloudpickle
.
loads
(
await
socket
.
recv
())
response
=
cloudpickle
.
loads
(
await
socket
.
recv
())
# Make a new socket connection.
if
socket
is
None
:
with
self
.
to_proxy_socket
()
as
socket
:
response
=
await
do_rpc_call
(
socket
,
request
,
timeout
)
# Use existing socket connection.
else
:
response
=
await
do_rpc_call
(
socket
,
request
,
timeout
)
if
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
:
if
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
:
if
isinstance
(
response
,
Exception
):
logger
.
error
(
error_message
)
raise
response
raise
ValueError
(
error_message
)
raise
ValueError
(
error_message
)
return
response
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
...
@@ -130,13 +250,13 @@ class AsyncEngineRPCClient:
...
@@ -130,13 +250,13 @@ class AsyncEngineRPCClient:
async
def
is_tracing_enabled
(
self
)
->
bool
:
async
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracing_flag
return
self
.
tracing_flag
async
def
wait_for_server
(
self
):
async
def
_
wait_for_server
_rpc
(
self
):
"""Wait for the RPCServer to start up."""
"""Wait for the RPCServer to start up."""
await
self
.
_send_one_way_rpc_request
(
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_READY
,
request
=
RPCUtilityRequest
.
IS_SERVER_READY
,
error_message
=
"Unable to start RPC Server
.
"
,
error_message
=
"Unable to start RPC Server"
,
timeout
=
SERVER_START_TIMEOUT_MS
)
timeout
=
VLLM_RPC_
SERVER_START_TIMEOUT_MS
)
async
def
_get_model_config_rpc
(
self
)
->
ModelConfig
:
async
def
_get_model_config_rpc
(
self
)
->
ModelConfig
:
"""Get the ModelConfig object from the RPC Server"""
"""Get the ModelConfig object from the RPC Server"""
...
@@ -184,8 +304,7 @@ class AsyncEngineRPCClient:
...
@@ -184,8 +304,7 @@ class AsyncEngineRPCClient:
return
await
self
.
_send_get_data_rpc_request
(
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
IS_TRACING_ENABLED
,
RPCUtilityRequest
.
IS_TRACING_ENABLED
,
expected_type
=
bool
,
expected_type
=
bool
,
error_message
=
"Could not get is_tracing_enabled flag from RPC "
error_message
=
"Could not get is_tracing_enabled from RPC Server"
)
"Server"
)
async
def
abort
(
self
,
request_id
:
str
):
async
def
abort
(
self
,
request_id
:
str
):
"""Send an ABORT_REQUEST signal to the RPC Server"""
"""Send an ABORT_REQUEST signal to the RPC Server"""
...
@@ -226,8 +345,7 @@ class AsyncEngineRPCClient:
...
@@ -226,8 +345,7 @@ class AsyncEngineRPCClient:
finished
=
False
finished
=
False
try
:
try
:
with
self
.
socket
()
as
socket
:
with
self
.
to_proxy_socket
()
as
socket
:
# Send RPCGenerateRequest to the RPCServer.
# Send RPCGenerateRequest to the RPCServer.
await
socket
.
send_multipart
([
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
cloudpickle
.
dumps
(
...
@@ -246,43 +364,37 @@ class AsyncEngineRPCClient:
...
@@ -246,43 +364,37 @@ class AsyncEngineRPCClient:
request_output
=
cloudpickle
.
loads
(
message
)
request_output
=
cloudpickle
.
loads
(
message
)
if
isinstance
(
request_output
,
Exception
):
if
isinstance
(
request_output
,
Exception
):
# On exception, check if the server is still healthy.
# On exception, check if the server is still healthy
# Use this to set the sync `is_running` and `errored`
# possibly setting the `errored` property.
# properties.
if
not
self
.
_errored
:
try
:
try
:
await
self
.
check_health
()
await
self
.
check_health
(
socket
=
socket
)
except
Exception
:
except
Exception
as
e
:
self
.
_errored
=
True
self
.
_errored
=
True
logger
.
exception
(
repr
(
e
))
# NB: do before raising here so that the flag is set
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
# by the time the caller receives this exception
raise
request_output
raise
request_output
finished
=
request_output
.
finished
finished
=
request_output
.
finished
yield
request_output
yield
request_output
finally
:
finally
:
if
not
finished
:
# Request was canceled by the client.
if
not
finished
and
not
self
.
_errored
:
await
self
.
abort
(
request_id
)
await
self
.
abort
(
request_id
)
async
def
check_health
(
self
)
->
None
:
async
def
check_health
(
self
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
)
->
None
:
"""Raise if unhealthy"""
"""Raise if unhealthy"""
with
self
.
socket
()
as
socket
:
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_HEALTHY
,
# Ping RPCServer with CHECK_HEALTH request.
error_message
=
"Got Unhealthy response from RPC Server"
,
await
socket
.
send
(
cloudpickle
.
dumps
(
RPCUtilityRequest
.
CHECK_HEALTH
)
timeout
=
VLLM_RPC_HEALTH_TIMEOUT_MS
,
)
socket
=
socket
)
# Await the reply from the server.
# TODO: do we need an internal timeout here?
# Or do we expect the external probe to timeout and let this chill?
health_message
=
cloudpickle
.
loads
(
await
socket
.
recv
())
if
isinstance
(
health_message
,
Exception
):
raise
health_message
if
health_message
!=
VLLM_RPC_HEALTHY_STR
:
raise
ValueError
(
"Expected healthy response from backend but got "
"f{health_message}"
)
async
def
encode
(
self
,
*
args
,
async
def
encode
(
self
,
*
args
,
**
kwargs
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
**
kwargs
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
...
vllm/entrypoints/openai/rpc/server.py
View file @
f7e3b0c5
import
asyncio
import
asyncio
import
signal
import
signal
from
typing
import
Any
,
Coroutine
from
typing
import
Any
,
Coroutine
,
Union
import
cloudpickle
import
cloudpickle
import
uvloop
import
uvloop
...
@@ -9,14 +9,19 @@ import zmq.asyncio
...
@@ -9,14 +9,19 @@ import zmq.asyncio
from
typing_extensions
import
Never
from
typing_extensions
import
Never
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.entrypoints.openai.rpc
import
(
VLLM_RPC_HEALTHY_STR
,
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.entrypoints.openai.rpc
import
(
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
RPCGenerateRequest
,
RPCUtilityRequest
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
]
class
AsyncEngineRPCServer
:
class
AsyncEngineRPCServer
:
...
@@ -29,9 +34,10 @@ class AsyncEngineRPCServer:
...
@@ -29,9 +34,10 @@ class AsyncEngineRPCServer:
# Initialize context.
# Initialize context.
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
context
=
zmq
.
asyncio
.
Context
()
# Init socket for readiness state.
# Init socket.
self
.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
socket
.
bind
(
rpc_path
)
self
.
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
socket
.
connect
(
rpc_path
)
def
cleanup
(
self
):
def
cleanup
(
self
):
"""Cleanup all resources."""
"""Cleanup all resources."""
...
@@ -41,39 +47,27 @@ class AsyncEngineRPCServer:
...
@@ -41,39 +47,27 @@ class AsyncEngineRPCServer:
# Clear the engine reference so that it can be GC'ed.
# Clear the engine reference so that it can be GC'ed.
del
self
.
engine
del
self
.
engine
async
def
get_model_config
(
self
,
identity
):
async
def
get_config
(
self
,
identity
,
request
):
"""Send the ModelConfig"""
try
:
model_config
=
await
self
.
engine
.
get_model_config
()
config
:
CONFIG_TYPE
if
request
==
RPCUtilityRequest
.
GET_MODEL_CONFIG
:
await
self
.
socket
.
send_multipart
(
config
=
await
self
.
engine
.
get_model_config
()
[
identity
,
cloudpickle
.
dumps
(
model_config
)])
elif
request
==
RPCUtilityRequest
.
GET_DECODING_CONFIG
:
config
=
await
self
.
engine
.
get_decoding_config
()
async
def
get_decoding_config
(
self
,
identity
):
elif
request
==
RPCUtilityRequest
.
GET_LORA_CONFIG
:
"""Send the DecodingConfig"""
config
=
await
self
.
engine
.
get_lora_config
()
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
elif
request
==
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
:
config
=
await
self
.
engine
.
get_scheduler_config
()
await
self
.
socket
.
send_multipart
(
elif
request
==
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
:
[
identity
,
cloudpickle
.
dumps
(
decoding_config
)])
config
=
await
self
.
engine
.
get_parallel_config
()
else
:
async
def
get_lora_config
(
self
,
identity
):
raise
ValueError
(
"Unknown Config Request: %s"
,
request
)
lora_config
=
await
self
.
engine
.
get_lora_config
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
lora_config
)])
async
def
get_scheduler_config
(
self
,
identity
):
"""Send the SchedulerConfig"""
parallel_config
=
await
self
.
engine
.
get_scheduler_config
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
parallel_config
)])
async
def
get_parallel_config
(
self
,
identity
):
await
self
.
socket
.
send_multipart
(
"""Send the ParallelConfig"""
[
identity
,
cloudpickle
.
dumps
(
config
)])
parallel_config
=
await
self
.
engine
.
get_parallel_config
()
await
self
.
socket
.
send_multipart
(
except
Exception
as
e
:
[
identity
,
cloudpickle
.
dumps
(
parallel_config
)])
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
e
)])
async
def
is_tracing_enabled
(
self
,
identity
):
async
def
is_tracing_enabled
(
self
,
identity
):
"""Send the is_tracing_enabled flag"""
"""Send the is_tracing_enabled flag"""
...
@@ -86,31 +80,23 @@ class AsyncEngineRPCServer:
...
@@ -86,31 +80,23 @@ class AsyncEngineRPCServer:
"""Log stats and confirm success."""
"""Log stats and confirm success."""
await
self
.
engine
.
do_log_stats
()
await
self
.
engine
.
do_log_stats
()
await
self
.
socket
.
send_multipart
([
await
self
.
socket
.
send_multipart
(
identity
,
[
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)])
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
])
async
def
is_server_ready
(
self
,
identity
):
async
def
is_server_ready
(
self
,
identity
):
"""Notify the client that we are ready."""
"""Notify the client that we are ready."""
await
self
.
socket
.
send_multipart
([
await
self
.
socket
.
send_multipart
(
identity
,
[
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)])
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
])
async
def
abort
(
self
,
identity
,
request
:
RPCAbortRequest
):
async
def
abort
(
self
,
identity
,
request
:
RPCAbortRequest
):
"""Abort request and notify the client of success."""
"""Abort request and notify the client of success."""
try
:
try
:
# Abort the request in the llm engine.
# Abort the request in the llm engine.
await
self
.
engine
.
abort
(
request
.
request_id
)
await
self
.
engine
.
abort
(
request
.
request_id
)
except
Exception
:
result
:
Union
[
str
,
Exception
]
=
VLLM_RPC_SUCCESS_STR
logger
.
warning
(
"Failed to abort request %s"
,
request
.
request_id
)
except
Exception
as
e
:
finally
:
result
=
e
# Send confirmation to the client.
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
result
)])
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
])
async
def
generate
(
self
,
identity
,
generate_request
:
RPCGenerateRequest
):
async
def
generate
(
self
,
identity
,
generate_request
:
RPCGenerateRequest
):
try
:
try
:
...
@@ -127,14 +113,14 @@ class AsyncEngineRPCServer:
...
@@ -127,14 +113,14 @@ class AsyncEngineRPCServer:
[
identity
,
cloudpickle
.
dumps
(
request_output
)])
[
identity
,
cloudpickle
.
dumps
(
request_output
)])
except
Exception
as
e
:
except
Exception
as
e
:
### Notify client of all failures
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
async
def
check_health
(
self
,
identity
):
async
def
check_health
(
self
,
identity
):
try
:
try
:
await
self
.
engine
.
check_health
()
await
self
.
engine
.
check_health
()
await
self
.
socket
.
send_multipart
(
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_HEALTHY_STR
)])
[
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)])
except
Exception
as
e
:
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
...
@@ -151,21 +137,19 @@ class AsyncEngineRPCServer:
...
@@ -151,21 +137,19 @@ class AsyncEngineRPCServer:
return
self
.
abort
(
identity
,
request
)
return
self
.
abort
(
identity
,
request
)
elif
isinstance
(
request
,
RPCUtilityRequest
):
elif
isinstance
(
request
,
RPCUtilityRequest
):
if
request
==
RPCUtilityRequest
.
GET_MODEL_CONFIG
:
if
request
in
[
return
self
.
get_model_config
(
identity
)
RPCUtilityRequest
.
GET_MODEL_CONFIG
,
elif
request
==
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
:
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
,
return
self
.
get_parallel_config
(
identity
)
RPCUtilityRequest
.
GET_DECODING_CONFIG
,
elif
request
==
RPCUtilityRequest
.
GET_DECODING_CONFIG
:
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
,
return
self
.
get_decoding_config
(
identity
)
RPCUtilityRequest
.
GET_LORA_CONFIG
elif
request
==
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
:
]:
return
self
.
get_scheduler_config
(
identity
)
return
self
.
get_config
(
identity
,
request
)
elif
request
==
RPCUtilityRequest
.
GET_LORA_CONFIG
:
return
self
.
get_lora_config
(
identity
)
elif
request
==
RPCUtilityRequest
.
DO_LOG_STATS
:
elif
request
==
RPCUtilityRequest
.
DO_LOG_STATS
:
return
self
.
do_log_stats
(
identity
)
return
self
.
do_log_stats
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_SERVER_READY
:
elif
request
==
RPCUtilityRequest
.
IS_SERVER_READY
:
return
self
.
is_server_ready
(
identity
)
return
self
.
is_server_ready
(
identity
)
elif
request
==
RPCUtilityRequest
.
CHECK
_HEALTH
:
elif
request
==
RPCUtilityRequest
.
IS_SERVER
_HEALTH
Y
:
return
self
.
check_health
(
identity
)
return
self
.
check_health
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_TRACING_ENABLED
:
elif
request
==
RPCUtilityRequest
.
IS_TRACING_ENABLED
:
return
self
.
is_tracing_enabled
(
identity
)
return
self
.
is_tracing_enabled
(
identity
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment