Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bdce64f2
Unverified
Commit
bdce64f2
authored
Jun 02, 2025
by
Rui Qiao
Committed by
GitHub
Jun 02, 2025
Browse files
[V1] Support DP with Ray (#18779)
parent
9e6f61e8
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
539 additions
and
108 deletions
+539
-108
requirements/test.in
requirements/test.in
+1
-1
requirements/test.txt
requirements/test.txt
+50
-0
tests/v1/test_async_llm_dp.py
tests/v1/test_async_llm_dp.py
+10
-3
vllm/config.py
vllm/config.py
+6
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+25
-4
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+30
-5
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+9
-4
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+123
-45
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+65
-9
vllm/v1/utils.py
vllm/v1/utils.py
+220
-37
No files found.
requirements/test.in
View file @
bdce64f2
...
...
@@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
vocos # required for minicpmo_26 test
peft
pqdm
ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
ray[cgraph
,default
]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
...
...
requirements/test.txt
View file @
bdce64f2
...
...
@@ -10,9 +10,13 @@ aiohappyeyeballs==2.4.3
# via aiohttp
aiohttp==3.10.11
# via
# aiohttp-cors
# datasets
# fsspec
# lm-eval
# ray
aiohttp-cors==0.8.1
# via ray
aiosignal==1.3.1
# via
# aiohttp
...
...
@@ -57,6 +61,8 @@ bounded-pool-executor==0.0.3
# via pqdm
buildkite-test-collector==0.1.9
# via -r requirements/test.in
cachetools==5.5.2
# via google-auth
certifi==2024.8.30
# via
# httpcore
...
...
@@ -81,6 +87,8 @@ colorama==0.4.6
# sacrebleu
# schemathesis
# tqdm-multiprocess
colorful==0.5.6
# via ray
contourpy==1.3.0
# via matplotlib
cramjam==2.9.0
...
...
@@ -108,6 +116,8 @@ dill==0.3.8
# evaluate
# lm-eval
# multiprocess
distlib==0.3.9
# via virtualenv
dnspython==2.7.0
# via email-validator
docopt==0.6.2
...
...
@@ -143,6 +153,7 @@ filelock==3.16.1
# ray
# torch
# transformers
# virtualenv
fonttools==4.54.1
# via matplotlib
fqdn==1.5.1
...
...
@@ -165,8 +176,16 @@ genai-perf==0.0.8
# via -r requirements/test.in
genson==1.3.0
# via datamodel-code-generator
google-api-core==2.24.2
# via opencensus
google-auth==2.40.2
# via google-api-core
googleapis-common-protos==1.70.0
# via google-api-core
graphql-core==3.2.6
# via hypothesis-graphql
grpcio==1.71.0
# via ray
h11==0.14.0
# via httpcore
harfile==0.3.0
...
...
@@ -392,6 +411,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
nvidia-nvtx-cu12==12.8.55
# via torch
opencensus==0.11.4
# via ray
opencensus-context==0.1.3
# via opencensus
opencv-python-headless==4.11.0.86
# via
# -r requirements/test.in
...
...
@@ -445,6 +468,7 @@ platformdirs==4.3.6
# via
# black
# pooch
# virtualenv
plotly==5.24.1
# via genai-perf
pluggy==1.5.0
...
...
@@ -457,10 +481,17 @@ portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements/test.in
prometheus-client==0.22.0
# via ray
propcache==0.2.0
# via yarl
proto-plus==1.26.1
# via google-api-core
protobuf==5.28.3
# via
# google-api-core
# googleapis-common-protos
# proto-plus
# ray
# tensorizer
psutil==6.1.0
...
...
@@ -470,10 +501,18 @@ psutil==6.1.0
# tensorizer
py==1.11.0
# via pytest-forked
py-spy==0.4.0
# via ray
pyarrow==18.0.0
# via
# datasets
# genai-perf
pyasn1==0.6.1
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
# via lm-eval
pycparser==2.22
...
...
@@ -486,6 +525,7 @@ pydantic==2.11.5
# datamodel-code-generator
# mistral-common
# mteb
# ray
pydantic-core==2.33.2
# via pydantic
pygments==2.18.0
...
...
@@ -573,6 +613,7 @@ requests==2.32.3
# buildkite-test-collector
# datasets
# evaluate
# google-api-core
# huggingface-hub
# lm-eval
# mistral-common
...
...
@@ -601,6 +642,8 @@ rpds-py==0.20.1
# via
# jsonschema
# referencing
rsa==4.9.1
# via google-auth
runai-model-streamer==0.11.0
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
...
...
@@ -648,9 +691,12 @@ shellingham==1.5.4
six==1.16.0
# via
# junit-xml
# opencensus
# python-dateutil
# rfc3339-validator
# rouge-score
smart-open==7.1.0
# via ray
sniffio==1.3.1
# via
# anyio
...
...
@@ -801,6 +847,8 @@ urllib3==2.2.3
# tritonclient
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
# via ray
vocos==0.1.0
# via -r requirements/test.in
webcolors==24.11.1
...
...
@@ -809,6 +857,8 @@ werkzeug==3.1.3
# via schemathesis
word2number==1.1
# via lm-eval
wrapt==1.17.2
# via smart-open
xxhash==3.5.0
# via
# datasets
...
...
tests/v1/test_async_llm_dp.py
View file @
bdce64f2
...
...
@@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,
@
pytest
.
mark
.
parametrize
(
"output_kind"
,
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
])
"output_kind"
,
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
,
],
)
@
pytest
.
mark
.
parametrize
(
"data_parallel_backend"
,
[
"mp"
,
"ray"
])
@
pytest
.
mark
.
asyncio
async
def
test_load
(
output_kind
:
RequestOutputKind
):
async
def
test_load
(
output_kind
:
RequestOutputKind
,
data_parallel_backend
:
str
):
with
ExitStack
()
as
after
:
prompt
=
"This is a test of data parallel"
engine_args
.
data_parallel_backend
=
data_parallel_backend
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
after
.
callback
(
engine
.
shutdown
)
...
...
@@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
asyncio
.
create_task
(
generate
(
engine
,
request_id
,
prompt
,
output_kind
,
NUM_EXPECTED_TOKENS
)))
# Confirm that we got all the EXPECTED tokens from the requests.
done
,
pending
=
await
asyncio
.
wait
(
tasks
,
return_when
=
asyncio
.
FIRST_EXCEPTION
)
...
...
vllm/config.py
View file @
bdce64f2
...
...
@@ -1742,6 +1742,8 @@ class ParallelConfig:
"""Port for data parallel messaging."""
data_parallel_master_port
:
int
=
29500
"""Port of the data parallel master."""
data_parallel_backend
:
str
=
"mp"
"""Backend to use for data parallel, either "mp" or "ray"."""
enable_expert_parallel
:
bool
=
False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
max_parallel_loading_workers
:
Optional
[
int
]
=
None
...
...
@@ -1911,6 +1913,10 @@ class ParallelConfig:
"please install Ray with `pip install "
"ray`."
)
from
ray_utils
.
ray_import_err
backend
=
"ray"
elif
self
.
data_parallel_backend
==
"ray"
:
logger
.
info
(
"Using ray distributed inference because "
"data_parallel_backend is ray"
)
backend
=
"ray"
elif
ray_found
:
if
self
.
placement_group
:
backend
=
"ray"
...
...
vllm/engine/arg_utils.py
View file @
bdce64f2
...
...
@@ -39,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
STR_DUAL_CHUNK_FLASH_ATTN_VAL
,
FlexibleArgumentParser
,
GiB_bytes
,
is_in_ray_actor
)
GiB_bytes
,
get_ip
,
is_in_ray_actor
)
# yapf: enable
...
...
@@ -292,6 +292,7 @@ class EngineArgs:
data_parallel_size_local
:
Optional
[
int
]
=
None
data_parallel_address
:
Optional
[
str
]
=
None
data_parallel_rpc_port
:
Optional
[
int
]
=
None
data_parallel_backend
:
str
=
ParallelConfig
.
data_parallel_backend
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
max_parallel_loading_workers
:
Optional
[
int
]
=
ParallelConfig
.
max_parallel_loading_workers
...
...
@@ -624,6 +625,12 @@ class EngineArgs:
type
=
int
,
help
=
'Port for data parallel RPC '
'communication.'
)
parallel_group
.
add_argument
(
'--data-parallel-backend'
,
'-dpb'
,
type
=
str
,
default
=
'mp'
,
help
=
'Backend for data parallel, either '
'"mp" or "ray".'
)
parallel_group
.
add_argument
(
"--enable-expert-parallel"
,
**
parallel_kwargs
[
"enable_expert_parallel"
])
...
...
@@ -1059,9 +1066,20 @@ class EngineArgs:
# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
data_parallel_address
=
self
.
data_parallel_address
if
(
self
.
data_parallel_address
is
not
None
)
else
ParallelConfig
.
data_parallel_master_ip
if
self
.
data_parallel_address
is
None
:
if
self
.
data_parallel_backend
==
"ray"
:
host_ip
=
get_ip
()
logger
.
info
(
"Using host IP %s as ray-based data parallel address"
,
host_ip
)
data_parallel_address
=
host_ip
else
:
assert
self
.
data_parallel_backend
==
"mp"
,
(
"data_parallel_backend can only be ray or mp, got %s"
,
self
.
data_parallel_backend
)
data_parallel_address
=
ParallelConfig
.
data_parallel_master_ip
else
:
data_parallel_address
=
self
.
data_parallel_address
# This port is only used when there are remote data parallel engines,
# otherwise the local IPC transport is used.
...
...
@@ -1069,6 +1087,8 @@ class EngineArgs:
self
.
data_parallel_rpc_port
is
not
None
)
else
ParallelConfig
.
data_parallel_rpc_port
data_parallel_backend
=
self
.
data_parallel_backend
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
...
...
@@ -1076,6 +1096,7 @@ class EngineArgs:
data_parallel_size_local
=
data_parallel_size_local
,
data_parallel_master_ip
=
data_parallel_address
,
data_parallel_rpc_port
=
data_parallel_rpc_port
,
data_parallel_backend
=
data_parallel_backend
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
...
...
vllm/entrypoints/cli/serve.py
View file @
bdce64f2
...
...
@@ -27,7 +27,8 @@ from vllm.v1.engine.core_client import CoreEngineProcManager
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.prometheus
import
setup_multiprocess_prometheus
from
vllm.v1.utils
import
(
APIServerProcessManager
,
CoreEngine
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
CoreEngineActorManager
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_completion_or_failure
,
wait_for_engine_startup
)
...
...
@@ -229,6 +230,31 @@ def run_multi_api_server(args: argparse.Namespace):
logger
.
info
(
"Started DP Coordinator process (PID: %d)"
,
coordinator
.
proc
.
pid
)
if
parallel_config
.
data_parallel_backend
==
"ray"
:
logger
.
info
(
"Starting ray-based data parallel backend"
)
engine_actor_manager
=
CoreEngineActorManager
(
vllm_config
=
vllm_config
,
addresses
=
addresses
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
)
# Start API servers using the manager
api_server_manager
=
APIServerProcessManager
(
target_server_fn
=
run_api_server_worker_proc
,
listen_address
=
listen_address
,
sock
=
sock
,
args
=
args
,
num_servers
=
num_api_servers
,
input_addresses
=
input_addresses
,
output_addresses
=
output_addresses
,
stats_update_address
=
stats_update_address
)
wait_for_completion_or_failure
(
api_server_manager
=
api_server_manager
,
engine_manager
=
engine_actor_manager
,
coordinator
=
coordinator
)
return
handshake_address
=
get_engine_client_zmq_addr
(
local_only
,
host
,
parallel_config
.
data_parallel_rpc_port
)
...
...
@@ -277,10 +303,9 @@ def run_multi_api_server(args: argparse.Namespace):
)
# Wait for API servers
wait_for_completion_or_failure
(
api_server_manager
=
api_server_manager
,
local_engine_manager
=
local_engine_manager
,
coordinator
=
coordinator
)
wait_for_completion_or_failure
(
api_server_manager
=
api_server_manager
,
engine_manager
=
local_engine_manager
,
coordinator
=
coordinator
)
def
run_api_server_worker_proc
(
listen_address
,
...
...
vllm/v1/engine/async_llm.py
View file @
bdce64f2
...
...
@@ -27,7 +27,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Device
,
cdiv
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core_client
import
AsyncMPClient
,
DPAsyncMPClient
from
vllm.v1.engine.core_client
import
(
AsyncMPClient
,
DPAsyncMPClient
,
RayDPClient
)
from
vllm.v1.engine.exceptions
import
EngineDeadError
,
EngineGenerateError
from
vllm.v1.engine.output_processor
import
(
OutputProcessor
,
RequestOutputCollector
)
...
...
@@ -119,9 +120,13 @@ class AsyncLLM(EngineClient):
log_stats
=
self
.
log_stats
)
# EngineCore (starts the engine in background process).
core_client_class
=
AsyncMPClient
if
(
vllm_config
.
parallel_config
.
data_parallel_size
==
1
)
else
DPAsyncMPClient
core_client_class
:
type
[
AsyncMPClient
]
if
vllm_config
.
parallel_config
.
data_parallel_size
==
1
:
core_client_class
=
AsyncMPClient
elif
vllm_config
.
parallel_config
.
data_parallel_backend
==
"ray"
:
core_client_class
=
RayDPClient
else
:
core_client_class
=
DPAsyncMPClient
self
.
engine_core
=
core_client_class
(
vllm_config
=
vllm_config
,
...
...
vllm/v1/engine/core.py
View file @
bdce64f2
...
...
@@ -6,8 +6,9 @@ import sys
import
threading
import
time
from
collections
import
deque
from
collections.abc
import
Generator
from
concurrent.futures
import
Future
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
,
contextmanager
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
...
...
@@ -367,60 +368,35 @@ class EngineCoreProc(EngineCore):
log_stats
:
bool
,
engine_index
:
int
=
0
,
):
input_queue
=
queue
.
Queue
[
tuple
[
EngineCoreRequestType
,
Any
]]()
executor_fail_callback
=
lambda
:
input_queue
.
put_nowait
(
self
.
input_queue
=
queue
.
Queue
[
tuple
[
EngineCoreRequestType
,
Any
]]()
self
.
output_queue
=
queue
.
Queue
[
Union
[
tuple
[
int
,
EngineCoreOutputs
],
bytes
]]()
executor_fail_callback
=
lambda
:
self
.
input_queue
.
put_nowait
(
(
EngineCoreRequestType
.
EXECUTOR_FAILED
,
b
''
))
# Create input socket.
input_ctx
=
zmq
.
Context
()
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
with
make_zmq_socket
(
input_ctx
,
handshake_address
,
zmq
.
DEALER
,
identity
=
identity
,
linger
=
5000
,
bind
=
False
)
as
handshake_socket
:
self
.
engine_index
=
engine_index
identity
=
self
.
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
self
.
engines_running
=
False
# Register engine with front-end.
addresses
=
self
.
startup_handshake
(
handshake_socket
,
on_head_node
,
vllm_config
.
parallel_config
)
with
self
.
_perform_handshake
(
handshake_address
,
identity
,
on_head_node
,
vllm_config
)
as
addresses
:
self
.
client_count
=
len
(
addresses
.
outputs
)
# Update config which may have changed from the handshake.
vllm_config
.
__post_init__
()
# Set up data parallel environment.
self
.
has_coordinator
=
addresses
.
coordinator_output
is
not
None
self
.
_init_data_parallel
(
vllm_config
)
# Initialize engine core and model.
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
executor_fail_callback
)
self
.
engine_index
=
engine_index
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
self
.
engines_running
=
False
self
.
last_counts
=
(
0
,
0
)
# Send ready message.
num_gpu_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
handshake_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"READY"
,
"local"
:
on_head_node
,
"num_gpu_blocks"
:
num_gpu_blocks
,
}))
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
input_queue
self
.
output_queue
=
queue
.
Queue
[
Union
[
tuple
[
int
,
EngineCoreOutputs
],
bytes
]]()
threading
.
Thread
(
target
=
self
.
process_input_sockets
,
args
=
(
addresses
.
inputs
,
addresses
.
coordinator_input
,
identity
),
...
...
@@ -428,10 +404,40 @@ class EngineCoreProc(EngineCore):
self
.
output_thread
=
threading
.
Thread
(
target
=
self
.
process_output_sockets
,
args
=
(
addresses
.
outputs
,
addresses
.
coordinator_output
,
engine_index
),
self
.
engine_index
),
daemon
=
True
)
self
.
output_thread
.
start
()
@contextmanager
def _perform_handshake(
        self, handshake_address: str, identity: bytes, on_head_node: bool,
        vllm_config: VllmConfig
) -> Generator[EngineZmqAddresses, None, None]:
    """Run the startup handshake around the caller's ``with`` body.

    On entry a DEALER socket is connected (not bound) to
    ``handshake_address`` using ``identity``, ``startup_handshake`` is
    called on it, and ``vllm_config.__post_init__()`` is re-run. The
    addresses returned by the handshake are yielded to the caller.

    When the ``with`` body finishes normally, a msgpack-encoded READY
    message carrying ``on_head_node`` and the current
    ``cache_config.num_gpu_blocks`` is sent on the same socket. The yield
    is not wrapped in try/finally, so nothing is sent if the body raises.
    """
    zmq_ctx = zmq.Context()
    with make_zmq_socket(zmq_ctx,
                         handshake_address,
                         zmq.DEALER,
                         identity=identity,
                         linger=5000,
                         bind=False) as sock:
        # Register engine with front-end.
        engine_addresses = self.startup_handshake(
            sock, on_head_node, vllm_config.parallel_config)

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

        yield engine_addresses

        # Send ready message. num_gpu_blocks is read only now, after the
        # caller's body has run.
        ready_payload = {
            "status": "READY",
            "local": on_head_node,
            "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
        }
        sock.send(msgspec.msgpack.encode(ready_payload))
@
staticmethod
def
startup_handshake
(
handshake_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
...
...
@@ -743,24 +749,29 @@ class DPEngineCoreProc(EngineCoreProc):
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from
multiprocessing
import
current_process
process_name
=
current_process
().
name
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
self
.
_decorate_logs
()
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self
.
counter
=
0
self
.
current_wave
=
0
self
.
last_counts
=
(
0
,
0
)
# Initialize the engine.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
super
().
__init__
(
vllm_config
,
on_head_node
,
handshake_address
,
executor_class
,
log_stats
,
dp_rank
)
def _decorate_logs(self):
    """Prefix this process's stdout and stderr with its name and PID.

    Both streams are passed to ``_add_prefix`` together with the current
    multiprocessing process name and ``os.getpid()``.
    """
    # Add process-specific prefix to stdout and stderr before
    # we initialize the engine.
    from multiprocessing import current_process

    proc_name = current_process().name
    proc_id = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, proc_name, proc_id)
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
# Configure GPUs and stateless process group for data parallel.
...
...
@@ -880,3 +891,70 @@ class DPEngineCoreProc(EngineCoreProc):
return
ParallelConfig
.
has_unfinished_dp
(
self
.
dp_group
,
local_unfinished
)
class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Configure the DP ranks on ``vllm_config`` and start the engine.

        Args:
            vllm_config: Engine configuration; its ``parallel_config`` is
                mutated in place with ``dp_rank`` / ``local_dp_rank``.
            on_head_node: Forwarded unchanged to the parent constructor.
            addresses: ZMQ addresses, stored and later yielded by
                ``_perform_handshake`` instead of doing a real handshake.
            executor_class: Forwarded unchanged to the parent constructor.
            log_stats: Forwarded unchanged to the parent constructor.
            dp_rank: Value written to ``data_parallel_rank``.
            local_dp_rank: Value written to ``data_parallel_rank_local``.
        """
        # Must be set before super().__init__(), which invokes
        # _perform_handshake() and reads self.addresses.
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
        # we clean this up to be able to properly initialize
        # data parallel groups.
        # pop() with a default instead of `del`: the variable may be absent
        # entirely, and `del` would then abort actor creation with KeyError.
        os.environ.pop('CUDA_VISIBLE_DEVICES', None)

        # The handshake address is unused for Ray, hence the empty string.
        super().__init__(vllm_config, on_head_node, "", executor_class,
                         log_stats)

    def _decorate_logs(self):
        """Intentionally a no-op override: no stdout/stderr prefix is added."""
        pass

    @contextmanager
    def _perform_handshake(self, handshake_address: str, identity: bytes,
                           on_head_node: bool, vllm_config: VllmConfig):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        SystemExit and any other Exception are logged and re-raised;
        ``shutdown()`` is always called on the way out.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()
vllm/v1/engine/core_client.py
View file @
bdce64f2
...
...
@@ -29,9 +29,9 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
(
CoreEngine
,
CoreEngine
Proc
Manager
,
Engine
ZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_engine_startup
)
from
vllm.v1.utils
import
(
CoreEngine
,
CoreEngine
Actor
Manager
,
Core
Engine
ProcManager
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_engine_startup
)
logger
=
init_logger
(
__name__
)
...
...
@@ -68,6 +68,8 @@ class EngineCoreClient(ABC):
if
multiprocess_mode
and
asyncio_mode
:
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
if
vllm_config
.
parallel_config
.
data_parallel_backend
==
"ray"
:
return
RayDPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
DPAsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
AsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
...
...
@@ -273,7 +275,10 @@ class BackgroundResources:
circular reference back to the client object."""
ctx
:
Union
[
zmq
.
Context
]
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
# If CoreEngineProcManager, it manages local engines;
# if CoreEngineActorManager, it manages all engines.
engine_manager
:
Optional
[
Union
[
CoreEngineProcManager
,
CoreEngineActorManager
]]
=
None
coordinator
:
Optional
[
DPCoordinator
]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
...
...
@@ -290,8 +295,8 @@ class BackgroundResources:
"""Clean up background resources."""
self
.
engine_dead
=
True
if
self
.
local_
engine_manager
is
not
None
:
self
.
local_
engine_manager
.
close
()
if
self
.
engine_manager
is
not
None
:
self
.
engine_manager
.
close
()
if
self
.
coordinator
is
not
None
:
self
.
coordinator
.
close
()
...
...
@@ -457,7 +462,7 @@ class MPClient(EngineCoreClient):
if
local_engine_count
:
# In server mode, start_index and local_start_index will
# both be 0.
self
.
resources
.
local_
engine_manager
=
CoreEngineProcManager
(
self
.
resources
.
engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
...
...
@@ -484,13 +489,18 @@ class MPClient(EngineCoreClient):
addresses
.
coordinator_input
,
addresses
.
coordinator_output
=
(
coordinator
.
get_engine_socket_addresses
())
proc_manager
=
self
.
resources
.
engine_manager
assert
isinstance
(
proc_manager
,
(
type
(
None
),
CoreEngineProcManager
)),
(
"_wait_for_engine_startup should only be "
"called with CoreEngineProcManager"
)
wait_for_engine_startup
(
handshake_socket
,
addresses
,
self
.
core_engines
,
self
.
vllm_config
.
parallel_config
,
self
.
vllm_config
.
cache_config
,
self
.
resources
.
local_engine
_manager
,
proc
_manager
,
coordinator
.
proc
if
coordinator
else
None
,
)
...
...
@@ -887,7 +897,6 @@ class DPAsyncMPClient(AsyncMPClient):
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
):
self
.
current_wave
=
0
self
.
engines_running
=
False
# To route aborts to the correct engine.
...
...
@@ -1050,3 +1059,50 @@ class DPAsyncMPClient(AsyncMPClient):
if
not
self
.
resources
.
engine_dead
:
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
,
engine
)
class RayDPClient(DPAsyncMPClient):
    """
    Ray-based client for multi-proc, multi-engine (data parallel)
    EngineCore.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ):
        """Forward every argument unchanged to ``DPAsyncMPClient``."""
        super().__init__(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)

    def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
                             local_start_index: int, input_address: str,
                             output_address: str,
                             executor_class: type[Executor], log_stats: bool):
        """Self-contained client mode, launch engine and coordinator process
        as needed.

        Requires ``data_parallel_rank == 0`` and ``local_start_index == 0``.
        A ``DPCoordinator`` is created only when there is more than one
        core engine; all engines are then started through a
        ``CoreEngineActorManager`` stored on ``self.resources``.
        """
        dp_config = vllm_config.parallel_config
        assert dp_config.data_parallel_rank == 0
        assert local_start_index == 0

        zmq_addresses = EngineZmqAddresses(
            inputs=[input_address],
            outputs=[output_address],
        )

        has_multiple_engines = len(self.core_engines) > 1
        if has_multiple_engines:
            dp_coordinator = DPCoordinator(dp_config)
            self.resources.coordinator = dp_coordinator
            coord_in, coord_out = dp_coordinator.get_engine_socket_addresses()
            zmq_addresses.coordinator_input = coord_in
            zmq_addresses.coordinator_output = coord_out

        # Start all engines.
        actor_manager = CoreEngineActorManager(vllm_config=vllm_config,
                                               addresses=zmq_addresses,
                                               executor_class=executor_class,
                                               log_stats=log_stats)
        self.resources.engine_manager = actor_manager
vllm/v1/utils.py
View file @
bdce64f2
...
...
@@ -27,6 +27,8 @@ from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
from
vllm.v1.executor.abstract
import
Executor
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
vllm.attention.layer
import
Attention
from
vllm.v1.engine.coordinator
import
DPCoordinator
...
...
@@ -112,6 +114,45 @@ def get_engine_client_zmq_addr(local_only: bool,
host
,
port
or
get_open_port
()))
class CoreEngineState(Enum):
    """State tag for a ``CoreEngine``; new instances start as ``NEW``."""

    # Members take consecutive auto() values in declaration order.
    NEW = auto()
    CONNECTED = auto()
    READY = auto()
class CoreEngine:
    """One per data parallel rank."""

    def __init__(self, index: int = 0, local: bool = True):
        """Record the engine's index, locality flag, identity and state.

        Args:
            index: Engine index; also encoded as the 2-byte little-endian
                ``identity``.
            local: Stored as-is on ``self.local``.
        """
        self.local = local
        self.index = index
        # 2-byte little-endian encoding of the index.
        self.identity = index.to_bytes(length=2, byteorder="little")

        self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
    """ZMQ socket addresses handed to each engine.

    ``inputs`` and ``outputs`` are required; the two coordinator addresses
    default to ``None``.
    """

    # ZMQ input socket addresses for each front-end client (requests)
    inputs: list[str]
    # ZMQ output socket addresses for each front-end client (responses)
    outputs: list[str]
    # ZMQ input socket address of DP coordinator if applicable
    coordinator_input: Optional[str] = None
    # ZMQ output socket address of DP coordinator if applicable
    coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
    """Metadata sent to each engine process during startup handshake,
    including addresses of the front-end ZMQ queues that they should
    connect to.
    """

    # Front-end ZMQ addresses for the engine.
    addresses: EngineZmqAddresses
    # Parallel-config values keyed by name; values are ints or strings.
    parallel_config: dict[str, Union[int, str]]
class
APIServerProcessManager
:
"""Manages a group of API server processes.
...
...
@@ -245,43 +286,168 @@ class CoreEngineProcManager:
}
class
CoreEngineState
(
Enum
):
NEW
=
auto
()
CONNECTED
=
auto
()
READY
=
auto
()
class
CoreEngineActorManager
:
"""
Utility class to handle creation, readiness, and shutdown
of core engine Ray actors used by the AsyncLLM and LLMEngine.
class
CoreEngine
:
"""One per data parallel rank."""
Different from CoreEngineProcManager, this class manages
core engines for both local and remote nodes.
"""
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
self
.
local
=
local
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
2
,
"little"
)
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
addresses
:
EngineZmqAddresses
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
placement_groups
:
Optional
[
list
[
"PlacementGroup"
]]
=
None
,
local_dp_ranks
:
Optional
[
list
[
int
]]
=
None
,
):
import
copy
self
.
state
=
CoreEngineState
.
NEW
import
ray
from
ray.util.scheduling_strategies
import
(
PlacementGroupSchedulingStrategy
)
from
vllm.v1.engine.core
import
DPEngineCoreActor
@
dataclass
class
EngineZmqAddresses
:
# ZMQ input socket addresses for each front-end client (requests)
inputs
:
list
[
str
]
# ZMQ output socket addresses for each front-end client (responses)
outputs
:
list
[
str
]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input
:
Optional
[
str
]
=
None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output
:
Optional
[
str
]
=
None
self
.
local_engine_actors
:
list
[
ray
.
ActorHandle
]
=
[]
self
.
remote_engine_actors
:
list
[
ray
.
ActorHandle
]
=
[]
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
local_engine_count
=
\
vllm_config
.
parallel_config
.
data_parallel_size_local
world_size
=
vllm_config
.
parallel_config
.
world_size
if
ray
.
is_initialized
():
logger
.
info
(
"Ray is already initialized. Skipping Ray initialization."
)
else
:
ray
.
init
()
if
placement_groups
is
not
None
:
assert
local_dp_ranks
is
not
None
,
(
"local_dp_ranks must be provided if "
"placement_groups is provided"
)
assert
len
(
placement_groups
)
==
len
(
local_dp_ranks
),
(
"placement_groups and local_dp_ranks must "
"have the same length"
)
logger
.
info
(
"Using provided placement groups"
)
# TODO(rui): validate passed-in placement groups
self
.
created_placement_groups
=
[]
else
:
placement_groups
,
local_dp_ranks
=
\
CoreEngineActorManager
.
create_dp_placement_groups
(
vllm_config
)
self
.
created_placement_groups
=
placement_groups
assert
len
(
placement_groups
)
==
dp_size
,
(
"Number of placement groups must match data parallel size"
)
refs
=
[]
for
index
in
range
(
dp_size
):
local_index
=
local_dp_ranks
[
index
]
dp_vllm_config
=
copy
.
deepcopy
(
vllm_config
)
pg
=
placement_groups
[
index
]
dp_vllm_config
.
parallel_config
.
placement_group
=
pg
on_head_node
=
index
<
local_engine_count
actor
=
ray
.
remote
(
DPEngineCoreActor
).
options
(
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
pg
,
placement_group_bundle_index
=
world_size
,
)).
remote
(
vllm_config
=
dp_vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
on_head_node
=
on_head_node
,
addresses
=
addresses
,
dp_rank
=
index
,
local_dp_rank
=
local_index
)
if
on_head_node
:
self
.
local_engine_actors
.
append
(
actor
)
else
:
self
.
remote_engine_actors
.
append
(
actor
)
refs
.
append
(
actor
.
wait_for_init
.
remote
())
ray
.
get
(
refs
)
self
.
run_refs
=
[]
for
actor
in
self
.
local_engine_actors
+
self
.
remote_engine_actors
:
self
.
run_refs
.
append
(
actor
.
run
.
remote
())
@staticmethod
def create_dp_placement_groups(
        vllm_config: VllmConfig
) -> tuple[list["PlacementGroup"], list[int]]:
    """Create one Ray placement group per data parallel rank.

    Nodes are ordered so that the node whose IP equals
    ``data_parallel_master_ip`` comes first. On that node exactly
    ``data_parallel_size_local`` groups are created; on every other node
    as many groups as its available GPUs allow, until
    ``data_parallel_size`` groups exist in total. Each group uses the
    ``STRICT_PACK`` strategy with ``world_size`` one-GPU bundles followed
    by a single one-CPU bundle.

    Args:
        vllm_config: Source of the parallel settings read here
            (master IP, DP sizes, world size).

    Returns:
        ``(placement_groups, local_dp_ranks)``: parallel lists where
        ``local_dp_ranks[k]`` is the per-node index of group ``k``.

    Raises:
        AssertionError: If no node matches the master IP, more than one
            does, or the master node lacks GPUs for its local ranks.
    """
    import ray
    from ray._private.state import available_resources_per_node
    from ray.util.state import list_nodes

    logger.info("Creating placement groups for data parallel")
    parallel_config = vllm_config.parallel_config
    dp_master_ip = parallel_config.data_parallel_master_ip
    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    world_size = parallel_config.world_size

    # Query the cluster once. The key is False for the master node, so it
    # sorts first; the sort is stable for the remaining nodes.
    nodes = sorted(list_nodes(),
                   key=lambda node: node.node_ip != dp_master_ip)
    # `nodes and ...` turns an empty cluster listing into the intended
    # assertion failure instead of an IndexError.
    assert nodes and nodes[0].node_ip == dp_master_ip, (
        "The first node must be the head node")
    assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
        "There can only be one head node")

    available_resources = available_resources_per_node()
    placement_groups: list[PlacementGroup] = []
    local_dp_ranks: list[int] = []

    for node in nodes:
        node_ip = node.node_ip
        # A node with no entry, or no "GPU" resource, contributes zero
        # engines rather than raising KeyError.
        node_resources = available_resources.get(node.node_id, {})
        # For now, each DP rank can only be assigned to one node
        # TODO(rui): support allocating a single DP rank
        # to multiple nodes
        # Ray reports resource amounts as floats; convert to int before
        # the division so the result is usable with range() below.
        available_engine_count = (int(node_resources.get("GPU", 0)) //
                                  world_size)
        if node_ip == dp_master_ip:
            assert available_engine_count >= local_engine_count, (
                "Not enough resources to allocate DP ranks "
                f"on DP master node {node_ip}")
            for i in range(local_engine_count):
                # The tiny "node:<ip>" request is presumably what pins
                # these bundles to the master node — confirm against Ray
                # node-resource semantics.
                bundles = [{
                    "GPU": 1.0,
                    "node:" + dp_master_ip: 0.001
                }] * world_size + [{
                    "CPU": 1.0
                }]
                pg = ray.util.placement_group(
                    name=f"dp_rank_{len(placement_groups)}",
                    strategy="STRICT_PACK",
                    bundles=bundles,
                )
                placement_groups.append(pg)
                local_dp_ranks.append(i)
        else:
            for i in range(available_engine_count):
                # Stop once every DP rank has a placement group.
                if len(placement_groups) == dp_size:
                    break
                bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
                pg = ray.util.placement_group(
                    name=f"dp_rank_{len(placement_groups)}",
                    strategy="STRICT_PACK",
                    bundles=bundles,
                )
                placement_groups.append(pg)
                local_dp_ranks.append(i)
    return placement_groups, local_dp_ranks
def
get_run_refs
(
self
):
return
self
.
run_refs
@
dataclass
class
EngineHandshakeMetadata
:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses
:
EngineZmqAddresses
parallel_config
:
dict
[
str
,
Union
[
int
,
str
]]
def close(self):
    """Kill every engine actor, then remove the placement groups this
    manager created itself.

    Local actors are killed before remote ones. Only the groups in
    ``created_placement_groups`` are removed.
    """
    import ray

    all_actors = [*self.local_engine_actors, *self.remote_engine_actors]
    for engine_actor in all_actors:
        ray.kill(engine_actor)
    for placement_group in self.created_placement_groups:
        ray.util.remove_placement_group(placement_group)
def
wait_for_engine_startup
(
...
...
@@ -383,11 +549,19 @@ def wait_for_engine_startup(
def
wait_for_completion_or_failure
(
api_server_manager
:
APIServerProcessManager
,
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
,
engine_manager
:
Optional
[
Union
[
CoreEngineProcManager
,
CoreEngineActorManager
]]
=
None
,
coordinator
:
Optional
[
"DPCoordinator"
]
=
None
)
->
None
:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
Args:
api_server_manager: The manager for API servers.
engine_manager: The manager for engine processes.
If CoreEngineProcManager, it manages local engines;
if CoreEngineActorManager, it manages all engines.
coordinator: The coordinator for data parallel.
"""
try
:
...
...
@@ -402,14 +576,18 @@ def wait_for_completion_or_failure(
if
coordinator
:
sentinel_to_proc
[
coordinator
.
proc
.
sentinel
]
=
coordinator
.
proc
if
local_engine_manager
:
for
proc
in
local_engine_manager
.
processes
:
actor_run_refs
=
[]
if
isinstance
(
engine_manager
,
CoreEngineProcManager
):
for
proc
in
engine_manager
.
processes
:
sentinel_to_proc
[
proc
.
sentinel
]
=
proc
elif
isinstance
(
engine_manager
,
CoreEngineActorManager
):
actor_run_refs
=
engine_manager
.
get_run_refs
()
# Check if any process terminates
while
sentinel_to_proc
:
while
sentinel_to_proc
or
actor_run_refs
:
# Wait for any process to terminate
ready_sentinels
:
list
[
Any
]
=
connection
.
wait
(
sentinel_to_proc
)
ready_sentinels
:
list
[
Any
]
=
connection
.
wait
(
sentinel_to_proc
,
timeout
=
5
)
# Process any terminated processes
for
sentinel
in
ready_sentinels
:
...
...
@@ -420,6 +598,11 @@ def wait_for_completion_or_failure(
raise
RuntimeError
(
f
"Process
{
proc
.
name
}
(PID:
{
proc
.
pid
}
) "
f
"died with exit code
{
proc
.
exitcode
}
"
)
if
actor_run_refs
:
import
ray
_
,
actor_run_refs
=
ray
.
wait
(
actor_run_refs
,
timeout
=
5
)
except
KeyboardInterrupt
:
logger
.
info
(
"Received KeyboardInterrupt, shutting down API servers..."
)
except
Exception
as
e
:
...
...
@@ -431,8 +614,8 @@ def wait_for_completion_or_failure(
api_server_manager
.
close
()
if
coordinator
:
coordinator
.
close
()
if
local_
engine_manager
:
local_
engine_manager
.
close
()
if
engine_manager
:
engine_manager
.
close
()
# Note(rob): shutdown function cannot be a bound method,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment