Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bdce64f2
Unverified
Commit
bdce64f2
authored
Jun 02, 2025
by
Rui Qiao
Committed by
GitHub
Jun 02, 2025
Browse files
[V1] Support DP with Ray (#18779)
parent
9e6f61e8
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
539 additions
and
108 deletions
+539
-108
requirements/test.in
requirements/test.in
+1
-1
requirements/test.txt
requirements/test.txt
+50
-0
tests/v1/test_async_llm_dp.py
tests/v1/test_async_llm_dp.py
+10
-3
vllm/config.py
vllm/config.py
+6
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+25
-4
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+30
-5
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+9
-4
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+123
-45
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+65
-9
vllm/v1/utils.py
vllm/v1/utils.py
+220
-37
No files found.
requirements/test.in
View file @
bdce64f2
...
@@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
...
@@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
vocos # required for minicpmo_26 test
vocos # required for minicpmo_26 test
peft
peft
pqdm
pqdm
ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
ray[cgraph
,default
]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
sentence-transformers # required for embedding tests
sentence-transformers # required for embedding tests
soundfile # required for audio tests
soundfile # required for audio tests
jiwer # required for audio tests
jiwer # required for audio tests
...
...
requirements/test.txt
View file @
bdce64f2
...
@@ -10,9 +10,13 @@ aiohappyeyeballs==2.4.3
...
@@ -10,9 +10,13 @@ aiohappyeyeballs==2.4.3
# via aiohttp
# via aiohttp
aiohttp==3.10.11
aiohttp==3.10.11
# via
# via
# aiohttp-cors
# datasets
# datasets
# fsspec
# fsspec
# lm-eval
# lm-eval
# ray
aiohttp-cors==0.8.1
# via ray
aiosignal==1.3.1
aiosignal==1.3.1
# via
# via
# aiohttp
# aiohttp
...
@@ -57,6 +61,8 @@ bounded-pool-executor==0.0.3
...
@@ -57,6 +61,8 @@ bounded-pool-executor==0.0.3
# via pqdm
# via pqdm
buildkite-test-collector==0.1.9
buildkite-test-collector==0.1.9
# via -r requirements/test.in
# via -r requirements/test.in
cachetools==5.5.2
# via google-auth
certifi==2024.8.30
certifi==2024.8.30
# via
# via
# httpcore
# httpcore
...
@@ -81,6 +87,8 @@ colorama==0.4.6
...
@@ -81,6 +87,8 @@ colorama==0.4.6
# sacrebleu
# sacrebleu
# schemathesis
# schemathesis
# tqdm-multiprocess
# tqdm-multiprocess
colorful==0.5.6
# via ray
contourpy==1.3.0
contourpy==1.3.0
# via matplotlib
# via matplotlib
cramjam==2.9.0
cramjam==2.9.0
...
@@ -108,6 +116,8 @@ dill==0.3.8
...
@@ -108,6 +116,8 @@ dill==0.3.8
# evaluate
# evaluate
# lm-eval
# lm-eval
# multiprocess
# multiprocess
distlib==0.3.9
# via virtualenv
dnspython==2.7.0
dnspython==2.7.0
# via email-validator
# via email-validator
docopt==0.6.2
docopt==0.6.2
...
@@ -143,6 +153,7 @@ filelock==3.16.1
...
@@ -143,6 +153,7 @@ filelock==3.16.1
# ray
# ray
# torch
# torch
# transformers
# transformers
# virtualenv
fonttools==4.54.1
fonttools==4.54.1
# via matplotlib
# via matplotlib
fqdn==1.5.1
fqdn==1.5.1
...
@@ -165,8 +176,16 @@ genai-perf==0.0.8
...
@@ -165,8 +176,16 @@ genai-perf==0.0.8
# via -r requirements/test.in
# via -r requirements/test.in
genson==1.3.0
genson==1.3.0
# via datamodel-code-generator
# via datamodel-code-generator
google-api-core==2.24.2
# via opencensus
google-auth==2.40.2
# via google-api-core
googleapis-common-protos==1.70.0
# via google-api-core
graphql-core==3.2.6
graphql-core==3.2.6
# via hypothesis-graphql
# via hypothesis-graphql
grpcio==1.71.0
# via ray
h11==0.14.0
h11==0.14.0
# via httpcore
# via httpcore
harfile==0.3.0
harfile==0.3.0
...
@@ -392,6 +411,10 @@ nvidia-nvjitlink-cu12==12.8.61
...
@@ -392,6 +411,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
# torch
nvidia-nvtx-cu12==12.8.55
nvidia-nvtx-cu12==12.8.55
# via torch
# via torch
opencensus==0.11.4
# via ray
opencensus-context==0.1.3
# via opencensus
opencv-python-headless==4.11.0.86
opencv-python-headless==4.11.0.86
# via
# via
# -r requirements/test.in
# -r requirements/test.in
...
@@ -445,6 +468,7 @@ platformdirs==4.3.6
...
@@ -445,6 +468,7 @@ platformdirs==4.3.6
# via
# via
# black
# black
# pooch
# pooch
# virtualenv
plotly==5.24.1
plotly==5.24.1
# via genai-perf
# via genai-perf
pluggy==1.5.0
pluggy==1.5.0
...
@@ -457,10 +481,17 @@ portalocker==2.10.1
...
@@ -457,10 +481,17 @@ portalocker==2.10.1
# via sacrebleu
# via sacrebleu
pqdm==0.2.0
pqdm==0.2.0
# via -r requirements/test.in
# via -r requirements/test.in
prometheus-client==0.22.0
# via ray
propcache==0.2.0
propcache==0.2.0
# via yarl
# via yarl
proto-plus==1.26.1
# via google-api-core
protobuf==5.28.3
protobuf==5.28.3
# via
# via
# google-api-core
# googleapis-common-protos
# proto-plus
# ray
# ray
# tensorizer
# tensorizer
psutil==6.1.0
psutil==6.1.0
...
@@ -470,10 +501,18 @@ psutil==6.1.0
...
@@ -470,10 +501,18 @@ psutil==6.1.0
# tensorizer
# tensorizer
py==1.11.0
py==1.11.0
# via pytest-forked
# via pytest-forked
py-spy==0.4.0
# via ray
pyarrow==18.0.0
pyarrow==18.0.0
# via
# via
# datasets
# datasets
# genai-perf
# genai-perf
pyasn1==0.6.1
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
pybind11==2.13.6
# via lm-eval
# via lm-eval
pycparser==2.22
pycparser==2.22
...
@@ -486,6 +525,7 @@ pydantic==2.11.5
...
@@ -486,6 +525,7 @@ pydantic==2.11.5
# datamodel-code-generator
# datamodel-code-generator
# mistral-common
# mistral-common
# mteb
# mteb
# ray
pydantic-core==2.33.2
pydantic-core==2.33.2
# via pydantic
# via pydantic
pygments==2.18.0
pygments==2.18.0
...
@@ -573,6 +613,7 @@ requests==2.32.3
...
@@ -573,6 +613,7 @@ requests==2.32.3
# buildkite-test-collector
# buildkite-test-collector
# datasets
# datasets
# evaluate
# evaluate
# google-api-core
# huggingface-hub
# huggingface-hub
# lm-eval
# lm-eval
# mistral-common
# mistral-common
...
@@ -601,6 +642,8 @@ rpds-py==0.20.1
...
@@ -601,6 +642,8 @@ rpds-py==0.20.1
# via
# via
# jsonschema
# jsonschema
# referencing
# referencing
rsa==4.9.1
# via google-auth
runai-model-streamer==0.11.0
runai-model-streamer==0.11.0
# via -r requirements/test.in
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
runai-model-streamer-s3==0.11.0
...
@@ -648,9 +691,12 @@ shellingham==1.5.4
...
@@ -648,9 +691,12 @@ shellingham==1.5.4
six==1.16.0
six==1.16.0
# via
# via
# junit-xml
# junit-xml
# opencensus
# python-dateutil
# python-dateutil
# rfc3339-validator
# rfc3339-validator
# rouge-score
# rouge-score
smart-open==7.1.0
# via ray
sniffio==1.3.1
sniffio==1.3.1
# via
# via
# anyio
# anyio
...
@@ -801,6 +847,8 @@ urllib3==2.2.3
...
@@ -801,6 +847,8 @@ urllib3==2.2.3
# tritonclient
# tritonclient
vector-quantize-pytorch==1.21.2
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
# via -r requirements/test.in
virtualenv==20.31.2
# via ray
vocos==0.1.0
vocos==0.1.0
# via -r requirements/test.in
# via -r requirements/test.in
webcolors==24.11.1
webcolors==24.11.1
...
@@ -809,6 +857,8 @@ werkzeug==3.1.3
...
@@ -809,6 +857,8 @@ werkzeug==3.1.3
# via schemathesis
# via schemathesis
word2number==1.1
word2number==1.1
# via lm-eval
# via lm-eval
wrapt==1.17.2
# via smart-open
xxhash==3.5.0
xxhash==3.5.0
# via
# via
# datasets
# datasets
...
...
tests/v1/test_async_llm_dp.py
View file @
bdce64f2
...
@@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,
...
@@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_kind"
,
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
])
"output_kind"
,
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
,
],
)
@
pytest
.
mark
.
parametrize
(
"data_parallel_backend"
,
[
"mp"
,
"ray"
])
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_load
(
output_kind
:
RequestOutputKind
):
async
def
test_load
(
output_kind
:
RequestOutputKind
,
data_parallel_backend
:
str
):
with
ExitStack
()
as
after
:
with
ExitStack
()
as
after
:
prompt
=
"This is a test of data parallel"
prompt
=
"This is a test of data parallel"
engine_args
.
data_parallel_backend
=
data_parallel_backend
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
after
.
callback
(
engine
.
shutdown
)
after
.
callback
(
engine
.
shutdown
)
...
@@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
...
@@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
asyncio
.
create_task
(
asyncio
.
create_task
(
generate
(
engine
,
request_id
,
prompt
,
output_kind
,
generate
(
engine
,
request_id
,
prompt
,
output_kind
,
NUM_EXPECTED_TOKENS
)))
NUM_EXPECTED_TOKENS
)))
# Confirm that we got all the EXPECTED tokens from the requests.
# Confirm that we got all the EXPECTED tokens from the requests.
done
,
pending
=
await
asyncio
.
wait
(
tasks
,
done
,
pending
=
await
asyncio
.
wait
(
tasks
,
return_when
=
asyncio
.
FIRST_EXCEPTION
)
return_when
=
asyncio
.
FIRST_EXCEPTION
)
...
...
vllm/config.py
View file @
bdce64f2
...
@@ -1742,6 +1742,8 @@ class ParallelConfig:
...
@@ -1742,6 +1742,8 @@ class ParallelConfig:
"""Port for data parallel messaging."""
"""Port for data parallel messaging."""
data_parallel_master_port
:
int
=
29500
data_parallel_master_port
:
int
=
29500
"""Port of the data parallel master."""
"""Port of the data parallel master."""
data_parallel_backend
:
str
=
"mp"
"""Backend to use for data parallel, either "mp" or "ray"."""
enable_expert_parallel
:
bool
=
False
enable_expert_parallel
:
bool
=
False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
max_parallel_loading_workers
:
Optional
[
int
]
=
None
max_parallel_loading_workers
:
Optional
[
int
]
=
None
...
@@ -1911,6 +1913,10 @@ class ParallelConfig:
...
@@ -1911,6 +1913,10 @@ class ParallelConfig:
"please install Ray with `pip install "
"please install Ray with `pip install "
"ray`."
)
from
ray_utils
.
ray_import_err
"ray`."
)
from
ray_utils
.
ray_import_err
backend
=
"ray"
backend
=
"ray"
elif
self
.
data_parallel_backend
==
"ray"
:
logger
.
info
(
"Using ray distributed inference because "
"data_parallel_backend is ray"
)
backend
=
"ray"
elif
ray_found
:
elif
ray_found
:
if
self
.
placement_group
:
if
self
.
placement_group
:
backend
=
"ray"
backend
=
"ray"
...
...
vllm/engine/arg_utils.py
View file @
bdce64f2
...
@@ -39,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
...
@@ -39,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
STR_DUAL_CHUNK_FLASH_ATTN_VAL
,
FlexibleArgumentParser
,
from
vllm.utils
import
(
STR_DUAL_CHUNK_FLASH_ATTN_VAL
,
FlexibleArgumentParser
,
GiB_bytes
,
is_in_ray_actor
)
GiB_bytes
,
get_ip
,
is_in_ray_actor
)
# yapf: enable
# yapf: enable
...
@@ -292,6 +292,7 @@ class EngineArgs:
...
@@ -292,6 +292,7 @@ class EngineArgs:
data_parallel_size_local
:
Optional
[
int
]
=
None
data_parallel_size_local
:
Optional
[
int
]
=
None
data_parallel_address
:
Optional
[
str
]
=
None
data_parallel_address
:
Optional
[
str
]
=
None
data_parallel_rpc_port
:
Optional
[
int
]
=
None
data_parallel_rpc_port
:
Optional
[
int
]
=
None
data_parallel_backend
:
str
=
ParallelConfig
.
data_parallel_backend
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
max_parallel_loading_workers
:
Optional
[
max_parallel_loading_workers
:
Optional
[
int
]
=
ParallelConfig
.
max_parallel_loading_workers
int
]
=
ParallelConfig
.
max_parallel_loading_workers
...
@@ -624,6 +625,12 @@ class EngineArgs:
...
@@ -624,6 +625,12 @@ class EngineArgs:
type
=
int
,
type
=
int
,
help
=
'Port for data parallel RPC '
help
=
'Port for data parallel RPC '
'communication.'
)
'communication.'
)
parallel_group
.
add_argument
(
'--data-parallel-backend'
,
'-dpb'
,
type
=
str
,
default
=
'mp'
,
help
=
'Backend for data parallel, either '
'"mp" or "ray".'
)
parallel_group
.
add_argument
(
parallel_group
.
add_argument
(
"--enable-expert-parallel"
,
"--enable-expert-parallel"
,
**
parallel_kwargs
[
"enable_expert_parallel"
])
**
parallel_kwargs
[
"enable_expert_parallel"
])
...
@@ -1059,9 +1066,20 @@ class EngineArgs:
...
@@ -1059,9 +1066,20 @@ class EngineArgs:
# DP address, used in multi-node case for torch distributed group
# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
# and ZMQ sockets.
data_parallel_address
=
self
.
data_parallel_address
if
(
if
self
.
data_parallel_address
is
None
:
self
.
data_parallel_address
if
self
.
data_parallel_backend
==
"ray"
:
is
not
None
)
else
ParallelConfig
.
data_parallel_master_ip
host_ip
=
get_ip
()
logger
.
info
(
"Using host IP %s as ray-based data parallel address"
,
host_ip
)
data_parallel_address
=
host_ip
else
:
assert
self
.
data_parallel_backend
==
"mp"
,
(
"data_parallel_backend can only be ray or mp, got %s"
,
self
.
data_parallel_backend
)
data_parallel_address
=
ParallelConfig
.
data_parallel_master_ip
else
:
data_parallel_address
=
self
.
data_parallel_address
# This port is only used when there are remote data parallel engines,
# This port is only used when there are remote data parallel engines,
# otherwise the local IPC transport is used.
# otherwise the local IPC transport is used.
...
@@ -1069,6 +1087,8 @@ class EngineArgs:
...
@@ -1069,6 +1087,8 @@ class EngineArgs:
self
.
data_parallel_rpc_port
self
.
data_parallel_rpc_port
is
not
None
)
else
ParallelConfig
.
data_parallel_rpc_port
is
not
None
)
else
ParallelConfig
.
data_parallel_rpc_port
data_parallel_backend
=
self
.
data_parallel_backend
parallel_config
=
ParallelConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
...
@@ -1076,6 +1096,7 @@ class EngineArgs:
...
@@ -1076,6 +1096,7 @@ class EngineArgs:
data_parallel_size_local
=
data_parallel_size_local
,
data_parallel_size_local
=
data_parallel_size_local
,
data_parallel_master_ip
=
data_parallel_address
,
data_parallel_master_ip
=
data_parallel_address
,
data_parallel_rpc_port
=
data_parallel_rpc_port
,
data_parallel_rpc_port
=
data_parallel_rpc_port
,
data_parallel_backend
=
data_parallel_backend
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
...
...
vllm/entrypoints/cli/serve.py
View file @
bdce64f2
...
@@ -27,7 +27,8 @@ from vllm.v1.engine.core_client import CoreEngineProcManager
...
@@ -27,7 +27,8 @@ from vllm.v1.engine.core_client import CoreEngineProcManager
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.prometheus
import
setup_multiprocess_prometheus
from
vllm.v1.metrics.prometheus
import
setup_multiprocess_prometheus
from
vllm.v1.utils
import
(
APIServerProcessManager
,
CoreEngine
,
from
vllm.v1.utils
import
(
APIServerProcessManager
,
CoreEngine
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
CoreEngineActorManager
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_completion_or_failure
,
wait_for_completion_or_failure
,
wait_for_engine_startup
)
wait_for_engine_startup
)
...
@@ -229,6 +230,31 @@ def run_multi_api_server(args: argparse.Namespace):
...
@@ -229,6 +230,31 @@ def run_multi_api_server(args: argparse.Namespace):
logger
.
info
(
"Started DP Coordinator process (PID: %d)"
,
logger
.
info
(
"Started DP Coordinator process (PID: %d)"
,
coordinator
.
proc
.
pid
)
coordinator
.
proc
.
pid
)
if
parallel_config
.
data_parallel_backend
==
"ray"
:
logger
.
info
(
"Starting ray-based data parallel backend"
)
engine_actor_manager
=
CoreEngineActorManager
(
vllm_config
=
vllm_config
,
addresses
=
addresses
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
)
# Start API servers using the manager
api_server_manager
=
APIServerProcessManager
(
target_server_fn
=
run_api_server_worker_proc
,
listen_address
=
listen_address
,
sock
=
sock
,
args
=
args
,
num_servers
=
num_api_servers
,
input_addresses
=
input_addresses
,
output_addresses
=
output_addresses
,
stats_update_address
=
stats_update_address
)
wait_for_completion_or_failure
(
api_server_manager
=
api_server_manager
,
engine_manager
=
engine_actor_manager
,
coordinator
=
coordinator
)
return
handshake_address
=
get_engine_client_zmq_addr
(
handshake_address
=
get_engine_client_zmq_addr
(
local_only
,
host
,
parallel_config
.
data_parallel_rpc_port
)
local_only
,
host
,
parallel_config
.
data_parallel_rpc_port
)
...
@@ -277,9 +303,8 @@ def run_multi_api_server(args: argparse.Namespace):
...
@@ -277,9 +303,8 @@ def run_multi_api_server(args: argparse.Namespace):
)
)
# Wait for API servers
# Wait for API servers
wait_for_completion_or_failure
(
wait_for_completion_or_failure
(
api_server_manager
=
api_server_manager
,
api_server_manager
=
api_server_manager
,
engine_manager
=
local_engine_manager
,
local_engine_manager
=
local_engine_manager
,
coordinator
=
coordinator
)
coordinator
=
coordinator
)
...
...
vllm/v1/engine/async_llm.py
View file @
bdce64f2
...
@@ -27,7 +27,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
...
@@ -27,7 +27,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Device
,
cdiv
from
vllm.utils
import
Device
,
cdiv
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core_client
import
AsyncMPClient
,
DPAsyncMPClient
from
vllm.v1.engine.core_client
import
(
AsyncMPClient
,
DPAsyncMPClient
,
RayDPClient
)
from
vllm.v1.engine.exceptions
import
EngineDeadError
,
EngineGenerateError
from
vllm.v1.engine.exceptions
import
EngineDeadError
,
EngineGenerateError
from
vllm.v1.engine.output_processor
import
(
OutputProcessor
,
from
vllm.v1.engine.output_processor
import
(
OutputProcessor
,
RequestOutputCollector
)
RequestOutputCollector
)
...
@@ -119,9 +120,13 @@ class AsyncLLM(EngineClient):
...
@@ -119,9 +120,13 @@ class AsyncLLM(EngineClient):
log_stats
=
self
.
log_stats
)
log_stats
=
self
.
log_stats
)
# EngineCore (starts the engine in background process).
# EngineCore (starts the engine in background process).
core_client_class
=
AsyncMPClient
if
(
core_client_class
:
type
[
AsyncMPClient
]
vllm_config
.
parallel_config
.
data_parallel_size
if
vllm_config
.
parallel_config
.
data_parallel_size
==
1
:
==
1
)
else
DPAsyncMPClient
core_client_class
=
AsyncMPClient
elif
vllm_config
.
parallel_config
.
data_parallel_backend
==
"ray"
:
core_client_class
=
RayDPClient
else
:
core_client_class
=
DPAsyncMPClient
self
.
engine_core
=
core_client_class
(
self
.
engine_core
=
core_client_class
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
...
...
vllm/v1/engine/core.py
View file @
bdce64f2
...
@@ -6,8 +6,9 @@ import sys
...
@@ -6,8 +6,9 @@ import sys
import
threading
import
threading
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
collections.abc
import
Generator
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
,
contextmanager
from
inspect
import
isclass
,
signature
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
logging
import
DEBUG
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
...
@@ -367,60 +368,35 @@ class EngineCoreProc(EngineCore):
...
@@ -367,60 +368,35 @@ class EngineCoreProc(EngineCore):
log_stats
:
bool
,
log_stats
:
bool
,
engine_index
:
int
=
0
,
engine_index
:
int
=
0
,
):
):
input_queue
=
queue
.
Queue
[
tuple
[
EngineCoreRequestType
,
Any
]]()
self
.
input_queue
=
queue
.
Queue
[
tuple
[
EngineCoreRequestType
,
Any
]]()
self
.
output_queue
=
queue
.
Queue
[
Union
[
tuple
[
int
,
EngineCoreOutputs
],
executor_fail_callback
=
lambda
:
input_queue
.
put_nowait
(
bytes
]]()
executor_fail_callback
=
lambda
:
self
.
input_queue
.
put_nowait
(
(
EngineCoreRequestType
.
EXECUTOR_FAILED
,
b
''
))
(
EngineCoreRequestType
.
EXECUTOR_FAILED
,
b
''
))
# Create input socket.
self
.
engine_index
=
engine_index
input_ctx
=
zmq
.
Context
()
identity
=
self
.
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
self
.
engines_running
=
False
with
make_zmq_socket
(
input_ctx
,
handshake_address
,
zmq
.
DEALER
,
identity
=
identity
,
linger
=
5000
,
bind
=
False
)
as
handshake_socket
:
# Register engine with front-end.
with
self
.
_perform_handshake
(
handshake_address
,
identity
,
on_head_node
,
addresses
=
self
.
startup_handshake
(
handshake_socket
,
on_head_node
,
vllm_config
)
as
addresses
:
vllm_config
.
parallel_config
)
self
.
client_count
=
len
(
addresses
.
outputs
)
self
.
client_count
=
len
(
addresses
.
outputs
)
# Update config which may have changed from the handshake.
vllm_config
.
__post_init__
()
# Set up data parallel environment.
# Set up data parallel environment.
self
.
has_coordinator
=
addresses
.
coordinator_output
is
not
None
self
.
has_coordinator
=
addresses
.
coordinator_output
is
not
None
self
.
_init_data_parallel
(
vllm_config
)
self
.
_init_data_parallel
(
vllm_config
)
# Initialize engine core and model.
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
executor_fail_callback
)
executor_fail_callback
)
self
.
engine_index
=
engine_index
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
self
.
step_with_batch_queue
)
self
.
engines_running
=
False
self
.
last_counts
=
(
0
,
0
)
# Send ready message.
num_gpu_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
handshake_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"READY"
,
"local"
:
on_head_node
,
"num_gpu_blocks"
:
num_gpu_blocks
,
}))
# Background Threads and Queues for IO. These enable us to
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# and to overlap some serialization/deserialization with the
# model forward pass.
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
input_queue
self
.
output_queue
=
queue
.
Queue
[
Union
[
tuple
[
int
,
EngineCoreOutputs
],
bytes
]]()
threading
.
Thread
(
target
=
self
.
process_input_sockets
,
threading
.
Thread
(
target
=
self
.
process_input_sockets
,
args
=
(
addresses
.
inputs
,
addresses
.
coordinator_input
,
args
=
(
addresses
.
inputs
,
addresses
.
coordinator_input
,
identity
),
identity
),
...
@@ -428,10 +404,40 @@ class EngineCoreProc(EngineCore):
...
@@ -428,10 +404,40 @@ class EngineCoreProc(EngineCore):
self
.
output_thread
=
threading
.
Thread
(
self
.
output_thread
=
threading
.
Thread
(
target
=
self
.
process_output_sockets
,
target
=
self
.
process_output_sockets
,
args
=
(
addresses
.
outputs
,
addresses
.
coordinator_output
,
args
=
(
addresses
.
outputs
,
addresses
.
coordinator_output
,
engine_index
),
self
.
engine_index
),
daemon
=
True
)
daemon
=
True
)
self
.
output_thread
.
start
()
self
.
output_thread
.
start
()
@
contextmanager
def
_perform_handshake
(
self
,
handshake_address
:
str
,
identity
:
bytes
,
on_head_node
:
bool
,
vllm_config
:
VllmConfig
)
->
Generator
[
EngineZmqAddresses
,
None
,
None
]:
input_ctx
=
zmq
.
Context
()
with
make_zmq_socket
(
input_ctx
,
handshake_address
,
zmq
.
DEALER
,
identity
=
identity
,
linger
=
5000
,
bind
=
False
)
as
handshake_socket
:
# Register engine with front-end.
addresses
=
self
.
startup_handshake
(
handshake_socket
,
on_head_node
,
vllm_config
.
parallel_config
)
# Update config which may have changed from the handshake
vllm_config
.
__post_init__
()
yield
addresses
# Send ready message.
num_gpu_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
handshake_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"READY"
,
"local"
:
on_head_node
,
"num_gpu_blocks"
:
num_gpu_blocks
,
}))
@
staticmethod
@
staticmethod
def
startup_handshake
(
def
startup_handshake
(
handshake_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
handshake_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
...
@@ -743,24 +749,29 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -743,24 +749,29 @@ class DPEngineCoreProc(EngineCoreProc):
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
):
):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
self
.
_decorate_logs
()
from
multiprocessing
import
current_process
process_name
=
current_process
().
name
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
# Counts forward-passes of the model so that we can synchronize
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
# finished with DP peers every N steps.
self
.
counter
=
0
self
.
counter
=
0
self
.
current_wave
=
0
self
.
current_wave
=
0
self
.
last_counts
=
(
0
,
0
)
# Initialize the engine.
# Initialize the engine.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
super
().
__init__
(
vllm_config
,
on_head_node
,
handshake_address
,
super
().
__init__
(
vllm_config
,
on_head_node
,
handshake_address
,
executor_class
,
log_stats
,
dp_rank
)
executor_class
,
log_stats
,
dp_rank
)
def
_decorate_logs
(
self
):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from
multiprocessing
import
current_process
process_name
=
current_process
().
name
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
# Configure GPUs and stateless process group for data parallel.
# Configure GPUs and stateless process group for data parallel.
...
@@ -880,3 +891,70 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -880,3 +891,70 @@ class DPEngineCoreProc(EngineCoreProc):
return
ParallelConfig
.
has_unfinished_dp
(
self
.
dp_group
,
return
ParallelConfig
.
has_unfinished_dp
(
self
.
dp_group
,
local_unfinished
)
local_unfinished
)
class
DPEngineCoreActor
(
DPEngineCoreProc
):
"""
Ray actor for running EngineCore in a data parallel context
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
addresses
:
EngineZmqAddresses
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
dp_rank
:
int
=
0
,
local_dp_rank
:
int
=
0
,
):
self
.
addresses
=
addresses
vllm_config
.
parallel_config
.
data_parallel_rank
=
dp_rank
vllm_config
.
parallel_config
.
data_parallel_rank_local
=
\
local_dp_rank
# Ray sets CUDA_VISIBLE_DEVICES to empty string,
# we clean this up to be able to properly initialize
# data parallel groups.
del
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
super
().
__init__
(
vllm_config
,
on_head_node
,
""
,
executor_class
,
log_stats
)
def
_decorate_logs
(
self
):
pass
@
contextmanager
def
_perform_handshake
(
self
,
handshake_address
:
str
,
identity
:
bytes
,
on_head_node
:
bool
,
vllm_config
:
VllmConfig
):
"""
For Ray, we don't need to actually perform handshake.
All addresses information is known before the actor creation.
Therefore, we simply yield these addresses.
"""
yield
self
.
addresses
def
wait_for_init
(
self
):
"""
Wait until the engine core is initialized.
This is just an empty method. When ray.get() on this method
(or any other method of the actor) returns, it is guaranteed
that actor creation (i.e., __init__) is complete.
"""
pass
def
run
(
self
):
"""
Run the engine core busy loop.
"""
try
:
self
.
run_busy_loop
()
except
SystemExit
:
logger
.
debug
(
"EngineCore exiting."
)
raise
except
Exception
:
logger
.
exception
(
"EngineCore encountered a fatal error."
)
raise
finally
:
self
.
shutdown
()
vllm/v1/engine/core_client.py
View file @
bdce64f2
...
@@ -29,9 +29,9 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc
...
@@ -29,9 +29,9 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
(
CoreEngine
,
CoreEngine
Proc
Manager
,
from
vllm.v1.utils
import
(
CoreEngine
,
CoreEngine
Actor
Manager
,
Engine
ZmqAddresses
,
get_engine_client_zmq_addr
,
Core
Engine
ProcManager
,
EngineZmqAddresses
,
wait_for_engine_startup
)
get_engine_client_zmq_addr
,
wait_for_engine_startup
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -68,6 +68,8 @@ class EngineCoreClient(ABC):
...
@@ -68,6 +68,8 @@ class EngineCoreClient(ABC):
if
multiprocess_mode
and
asyncio_mode
:
if
multiprocess_mode
and
asyncio_mode
:
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
if
vllm_config
.
parallel_config
.
data_parallel_backend
==
"ray"
:
return
RayDPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
DPAsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
DPAsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
AsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
AsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
...
@@ -273,7 +275,10 @@ class BackgroundResources:
...
@@ -273,7 +275,10 @@ class BackgroundResources:
circular reference back to the client object."""
circular reference back to the client object."""
ctx
:
Union
[
zmq
.
Context
]
ctx
:
Union
[
zmq
.
Context
]
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
# If CoreEngineProcManager, it manages local engines;
# if CoreEngineActorManager, it manages all engines.
engine_manager
:
Optional
[
Union
[
CoreEngineProcManager
,
CoreEngineActorManager
]]
=
None
coordinator
:
Optional
[
DPCoordinator
]
=
None
coordinator
:
Optional
[
DPCoordinator
]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
...
@@ -290,8 +295,8 @@ class BackgroundResources:
...
@@ -290,8 +295,8 @@ class BackgroundResources:
"""Clean up background resources."""
"""Clean up background resources."""
self
.
engine_dead
=
True
self
.
engine_dead
=
True
if
self
.
local_
engine_manager
is
not
None
:
if
self
.
engine_manager
is
not
None
:
self
.
local_
engine_manager
.
close
()
self
.
engine_manager
.
close
()
if
self
.
coordinator
is
not
None
:
if
self
.
coordinator
is
not
None
:
self
.
coordinator
.
close
()
self
.
coordinator
.
close
()
...
@@ -457,7 +462,7 @@ class MPClient(EngineCoreClient):
...
@@ -457,7 +462,7 @@ class MPClient(EngineCoreClient):
if
local_engine_count
:
if
local_engine_count
:
# In server mode, start_index and local_start_index will
# In server mode, start_index and local_start_index will
# both be 0.
# both be 0.
self
.
resources
.
local_
engine_manager
=
CoreEngineProcManager
(
self
.
resources
.
engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
executor_class
=
executor_class
,
...
@@ -484,13 +489,18 @@ class MPClient(EngineCoreClient):
...
@@ -484,13 +489,18 @@ class MPClient(EngineCoreClient):
addresses
.
coordinator_input
,
addresses
.
coordinator_output
=
(
addresses
.
coordinator_input
,
addresses
.
coordinator_output
=
(
coordinator
.
get_engine_socket_addresses
())
coordinator
.
get_engine_socket_addresses
())
proc_manager
=
self
.
resources
.
engine_manager
assert
isinstance
(
proc_manager
,
(
type
(
None
),
CoreEngineProcManager
)),
(
"_wait_for_engine_startup should only be "
"called with CoreEngineProcManager"
)
wait_for_engine_startup
(
wait_for_engine_startup
(
handshake_socket
,
handshake_socket
,
addresses
,
addresses
,
self
.
core_engines
,
self
.
core_engines
,
self
.
vllm_config
.
parallel_config
,
self
.
vllm_config
.
parallel_config
,
self
.
vllm_config
.
cache_config
,
self
.
vllm_config
.
cache_config
,
self
.
resources
.
local_engine
_manager
,
proc
_manager
,
coordinator
.
proc
if
coordinator
else
None
,
coordinator
.
proc
if
coordinator
else
None
,
)
)
...
@@ -887,7 +897,6 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -887,7 +897,6 @@ class DPAsyncMPClient(AsyncMPClient):
log_stats
:
bool
,
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
):
client_index
:
int
=
0
):
self
.
current_wave
=
0
self
.
current_wave
=
0
self
.
engines_running
=
False
self
.
engines_running
=
False
# To route aborts to the correct engine.
# To route aborts to the correct engine.
...
@@ -1050,3 +1059,50 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -1050,3 +1059,50 @@ class DPAsyncMPClient(AsyncMPClient):
if
not
self
.
resources
.
engine_dead
:
if
not
self
.
resources
.
engine_dead
:
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
,
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
,
engine
)
engine
)
class
RayDPClient
(
DPAsyncMPClient
):
"""
Ray-based client for multi-proc, multi-engine (data parallel)
EngineCore.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
,
):
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
client_addresses
,
client_index
)
def
_init_engines_direct
(
self
,
vllm_config
:
VllmConfig
,
local_only
:
bool
,
local_start_index
:
int
,
input_address
:
str
,
output_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
):
"""Self-contained client mode, launch engine and coordinator process
as needed."""
parallel_config
=
vllm_config
.
parallel_config
assert
parallel_config
.
data_parallel_rank
==
0
assert
local_start_index
==
0
addresses
=
EngineZmqAddresses
(
inputs
=
[
input_address
],
outputs
=
[
output_address
],
)
if
len
(
self
.
core_engines
)
>
1
:
coordinator
=
DPCoordinator
(
parallel_config
)
self
.
resources
.
coordinator
=
coordinator
addresses
.
coordinator_input
,
addresses
.
coordinator_output
=
(
coordinator
.
get_engine_socket_addresses
())
# Start all engines.
self
.
resources
.
engine_manager
=
CoreEngineActorManager
(
vllm_config
=
vllm_config
,
addresses
=
addresses
,
executor_class
=
executor_class
,
log_stats
=
log_stats
)
vllm/v1/utils.py
View file @
bdce64f2
...
@@ -27,6 +27,8 @@ from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
...
@@ -27,6 +27,8 @@ from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.coordinator
import
DPCoordinator
...
@@ -112,6 +114,45 @@ def get_engine_client_zmq_addr(local_only: bool,
...
@@ -112,6 +114,45 @@ def get_engine_client_zmq_addr(local_only: bool,
host
,
port
or
get_open_port
()))
host
,
port
or
get_open_port
()))
class
CoreEngineState
(
Enum
):
NEW
=
auto
()
CONNECTED
=
auto
()
READY
=
auto
()
class
CoreEngine
:
"""One per data parallel rank."""
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
self
.
local
=
local
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
2
,
"little"
)
self
.
state
=
CoreEngineState
.
NEW
@
dataclass
class
EngineZmqAddresses
:
# ZMQ input socket addresses for each front-end client (requests)
inputs
:
list
[
str
]
# ZMQ output socket addresses for each front-end client (responses)
outputs
:
list
[
str
]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input
:
Optional
[
str
]
=
None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output
:
Optional
[
str
]
=
None
@
dataclass
class
EngineHandshakeMetadata
:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses
:
EngineZmqAddresses
parallel_config
:
dict
[
str
,
Union
[
int
,
str
]]
class
APIServerProcessManager
:
class
APIServerProcessManager
:
"""Manages a group of API server processes.
"""Manages a group of API server processes.
...
@@ -245,43 +286,168 @@ class CoreEngineProcManager:
...
@@ -245,43 +286,168 @@ class CoreEngineProcManager:
}
}
class
CoreEngineState
(
Enum
):
class
CoreEngineActorManager
:
NEW
=
auto
()
"""
CONNECTED
=
auto
()
Utility class to handle creation, readiness, and shutdown
READY
=
auto
()
of core engine Ray actors used by the AsyncLLM and LLMEngine.
class
CoreEngine
:
Different from CoreEngineProcManager, this class manages
"""One per data parallel rank."""
core engines for both local and remote nodes.
"""
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
def
__init__
(
self
.
local
=
local
self
,
self
.
index
=
index
vllm_config
:
VllmConfig
,
self
.
identity
=
index
.
to_bytes
(
2
,
"little"
)
addresses
:
EngineZmqAddresses
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
placement_groups
:
Optional
[
list
[
"PlacementGroup"
]]
=
None
,
local_dp_ranks
:
Optional
[
list
[
int
]]
=
None
,
):
import
copy
self
.
state
=
CoreEngineState
.
NEW
import
ray
from
ray.util.scheduling_strategies
import
(
PlacementGroupSchedulingStrategy
)
from
vllm.v1.engine.core
import
DPEngineCoreActor
@
dataclass
self
.
local_engine_actors
:
list
[
ray
.
ActorHandle
]
=
[]
class
EngineZmqAddresses
:
self
.
remote_engine_actors
:
list
[
ray
.
ActorHandle
]
=
[]
# ZMQ input socket addresses for each front-end client (requests)
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
inputs
:
list
[
str
]
local_engine_count
=
\
# ZMQ output socket addresses for each front-end client (responses)
vllm_config
.
parallel_config
.
data_parallel_size_local
outputs
:
list
[
str
]
world_size
=
vllm_config
.
parallel_config
.
world_size
# ZMQ input socket address of DP coordinator if applicable
coordinator_input
:
Optional
[
str
]
=
None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output
:
Optional
[
str
]
=
None
if
ray
.
is_initialized
():
logger
.
info
(
"Ray is already initialized. Skipping Ray initialization."
)
else
:
ray
.
init
()
if
placement_groups
is
not
None
:
assert
local_dp_ranks
is
not
None
,
(
"local_dp_ranks must be provided if "
"placement_groups is provided"
)
assert
len
(
placement_groups
)
==
len
(
local_dp_ranks
),
(
"placement_groups and local_dp_ranks must "
"have the same length"
)
logger
.
info
(
"Using provided placement groups"
)
# TODO(rui): validate passed-in placement groups
self
.
created_placement_groups
=
[]
else
:
placement_groups
,
local_dp_ranks
=
\
CoreEngineActorManager
.
create_dp_placement_groups
(
vllm_config
)
self
.
created_placement_groups
=
placement_groups
assert
len
(
placement_groups
)
==
dp_size
,
(
"Number of placement groups must match data parallel size"
)
refs
=
[]
for
index
in
range
(
dp_size
):
local_index
=
local_dp_ranks
[
index
]
dp_vllm_config
=
copy
.
deepcopy
(
vllm_config
)
pg
=
placement_groups
[
index
]
dp_vllm_config
.
parallel_config
.
placement_group
=
pg
on_head_node
=
index
<
local_engine_count
actor
=
ray
.
remote
(
DPEngineCoreActor
).
options
(
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
pg
,
placement_group_bundle_index
=
world_size
,
)).
remote
(
vllm_config
=
dp_vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
on_head_node
=
on_head_node
,
addresses
=
addresses
,
dp_rank
=
index
,
local_dp_rank
=
local_index
)
if
on_head_node
:
self
.
local_engine_actors
.
append
(
actor
)
else
:
self
.
remote_engine_actors
.
append
(
actor
)
refs
.
append
(
actor
.
wait_for_init
.
remote
())
ray
.
get
(
refs
)
self
.
run_refs
=
[]
for
actor
in
self
.
local_engine_actors
+
self
.
remote_engine_actors
:
self
.
run_refs
.
append
(
actor
.
run
.
remote
())
@
staticmethod
def
create_dp_placement_groups
(
vllm_config
:
VllmConfig
)
->
tuple
[
list
[
"PlacementGroup"
],
list
[
int
]]:
import
ray
from
ray._private.state
import
available_resources_per_node
from
ray.util.state
import
list_nodes
logger
.
info
(
"Creating placement groups for data parallel"
)
dp_master_ip
=
\
vllm_config
.
parallel_config
.
data_parallel_master_ip
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
local_engine_count
=
\
vllm_config
.
parallel_config
.
data_parallel_size_local
nodes
=
list_nodes
()
nodes
=
sorted
(
list_nodes
(),
key
=
lambda
node
:
node
.
node_ip
!=
dp_master_ip
)
assert
nodes
[
0
].
node_ip
==
dp_master_ip
,
(
"The first node must be the head node"
)
assert
len
(
nodes
)
==
1
or
nodes
[
1
].
node_ip
!=
dp_master_ip
,
(
"There can only be one head node"
)
available_resources
=
available_resources_per_node
()
world_size
=
vllm_config
.
parallel_config
.
world_size
placement_groups
:
list
[
PlacementGroup
]
=
[]
local_dp_ranks
:
list
[
int
]
=
[]
for
node
in
nodes
:
node_ip
=
node
.
node_ip
node_resources
=
available_resources
[
node
.
node_id
]
# For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank
# to multiple nodes
available_engine_count
=
node_resources
[
"GPU"
]
//
world_size
if
node_ip
==
dp_master_ip
:
assert
available_engine_count
>=
local_engine_count
,
(
"Not enough resources to allocate DP ranks "
f
"on DP master node
{
node_ip
}
"
)
for
i
in
range
(
local_engine_count
):
bundles
=
[{
"GPU"
:
1.0
,
"node:"
+
dp_master_ip
:
0.001
}]
*
world_size
+
[{
"CPU"
:
1.0
}]
pg
=
ray
.
util
.
placement_group
(
name
=
f
"dp_rank_
{
len
(
placement_groups
)
}
"
,
strategy
=
"STRICT_PACK"
,
bundles
=
bundles
,
)
placement_groups
.
append
(
pg
)
local_dp_ranks
.
append
(
i
)
else
:
for
i
in
range
(
available_engine_count
):
if
len
(
placement_groups
)
==
dp_size
:
break
bundles
=
[{
"GPU"
:
1.0
}]
*
world_size
+
[{
"CPU"
:
1.0
}]
pg
=
ray
.
util
.
placement_group
(
name
=
f
"dp_rank_
{
len
(
placement_groups
)
}
"
,
strategy
=
"STRICT_PACK"
,
bundles
=
bundles
,
)
placement_groups
.
append
(
pg
)
local_dp_ranks
.
append
(
i
)
return
placement_groups
,
local_dp_ranks
def
get_run_refs
(
self
):
return
self
.
run_refs
@
dataclass
def
close
(
self
):
class
EngineHandshakeMetadata
:
import
ray
"""Metadata sent to each engine process during startup handshake,
for
actor
in
self
.
local_engine_actors
+
self
.
remote_engine_actors
:
including addresses of the front-end ZMQ queues that they should
ray
.
kill
(
actor
)
connect to.
for
pg
in
self
.
created_placement_groups
:
"""
ray
.
util
.
remove_placement_group
(
pg
)
addresses
:
EngineZmqAddresses
parallel_config
:
dict
[
str
,
Union
[
int
,
str
]]
def
wait_for_engine_startup
(
def
wait_for_engine_startup
(
...
@@ -383,11 +549,19 @@ def wait_for_engine_startup(
...
@@ -383,11 +549,19 @@ def wait_for_engine_startup(
def
wait_for_completion_or_failure
(
def
wait_for_completion_or_failure
(
api_server_manager
:
APIServerProcessManager
,
api_server_manager
:
APIServerProcessManager
,
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
,
engine_manager
:
Optional
[
Union
[
CoreEngineProcManager
,
CoreEngineActorManager
]]
=
None
,
coordinator
:
Optional
[
"DPCoordinator"
]
=
None
)
->
None
:
coordinator
:
Optional
[
"DPCoordinator"
]
=
None
)
->
None
:
"""Wait for all processes to complete or detect if any fail.
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
Raises an exception if any process exits with a non-zero status.
Args:
api_server_manager: The manager for API servers.
engine_manager: The manager for engine processes.
If CoreEngineProcManager, it manages local engines;
if CoreEngineActorManager, it manages all engines.
coordinator: The coordinator for data parallel.
"""
"""
try
:
try
:
...
@@ -402,14 +576,18 @@ def wait_for_completion_or_failure(
...
@@ -402,14 +576,18 @@ def wait_for_completion_or_failure(
if
coordinator
:
if
coordinator
:
sentinel_to_proc
[
coordinator
.
proc
.
sentinel
]
=
coordinator
.
proc
sentinel_to_proc
[
coordinator
.
proc
.
sentinel
]
=
coordinator
.
proc
if
local_engine_manager
:
actor_run_refs
=
[]
for
proc
in
local_engine_manager
.
processes
:
if
isinstance
(
engine_manager
,
CoreEngineProcManager
):
for
proc
in
engine_manager
.
processes
:
sentinel_to_proc
[
proc
.
sentinel
]
=
proc
sentinel_to_proc
[
proc
.
sentinel
]
=
proc
elif
isinstance
(
engine_manager
,
CoreEngineActorManager
):
actor_run_refs
=
engine_manager
.
get_run_refs
()
# Check if any process terminates
# Check if any process terminates
while
sentinel_to_proc
:
while
sentinel_to_proc
or
actor_run_refs
:
# Wait for any process to terminate
# Wait for any process to terminate
ready_sentinels
:
list
[
Any
]
=
connection
.
wait
(
sentinel_to_proc
)
ready_sentinels
:
list
[
Any
]
=
connection
.
wait
(
sentinel_to_proc
,
timeout
=
5
)
# Process any terminated processes
# Process any terminated processes
for
sentinel
in
ready_sentinels
:
for
sentinel
in
ready_sentinels
:
...
@@ -420,6 +598,11 @@ def wait_for_completion_or_failure(
...
@@ -420,6 +598,11 @@ def wait_for_completion_or_failure(
raise
RuntimeError
(
raise
RuntimeError
(
f
"Process
{
proc
.
name
}
(PID:
{
proc
.
pid
}
) "
f
"Process
{
proc
.
name
}
(PID:
{
proc
.
pid
}
) "
f
"died with exit code
{
proc
.
exitcode
}
"
)
f
"died with exit code
{
proc
.
exitcode
}
"
)
if
actor_run_refs
:
import
ray
_
,
actor_run_refs
=
ray
.
wait
(
actor_run_refs
,
timeout
=
5
)
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
logger
.
info
(
"Received KeyboardInterrupt, shutting down API servers..."
)
logger
.
info
(
"Received KeyboardInterrupt, shutting down API servers..."
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -431,8 +614,8 @@ def wait_for_completion_or_failure(
...
@@ -431,8 +614,8 @@ def wait_for_completion_or_failure(
api_server_manager
.
close
()
api_server_manager
.
close
()
if
coordinator
:
if
coordinator
:
coordinator
.
close
()
coordinator
.
close
()
if
local_
engine_manager
:
if
engine_manager
:
local_
engine_manager
.
close
()
engine_manager
.
close
()
# Note(rob): shutdown function cannot be a bound method,
# Note(rob): shutdown function cannot be a bound method,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment