Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
55aa7af9
Unverified
Commit
55aa7af9
authored
May 13, 2025
by
Nick Hill
Committed by
GitHub
May 13, 2025
Browse files
[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
0b217da6
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
525 additions
and
252 deletions
+525
-252
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+1
-1
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+8
-7
vllm/config.py
vllm/config.py
+19
-22
vllm/distributed/utils.py
vllm/distributed/utils.py
+2
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+38
-0
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+79
-2
vllm/utils.py
vllm/utils.py
+4
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+122
-67
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+174
-119
vllm/v1/utils.py
vllm/v1/utils.py
+78
-33
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
55aa7af9
...
...
@@ -41,7 +41,7 @@ class MockEngine:
self
.
abort_request_calls
=
0
self
.
request_id
=
None
# Ugly, remove dependency when possible
self
.
parallel_config
=
ParallelConfig
(
1
,
1
,
False
)
self
.
parallel_config
=
ParallelConfig
()
self
.
model_config
=
MockModelConfig
()
async
def
step_async
(
self
,
virtual_engine
):
...
...
tests/v1/engine/test_engine_core_client.py
View file @
55aa7af9
...
...
@@ -18,9 +18,10 @@ from vllm.platforms import current_platform
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core
import
EngineCore
from
vllm.v1.engine.core_client
import
(
AsyncMPClient
,
Core
Engine
,
EngineCoreClient
,
SyncMPClient
)
from
vllm.v1.engine.core_client
import
(
AsyncMPClient
,
Engine
CoreClient
,
SyncMPClient
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.utils
import
CoreEngineProcManager
from
...distributed.conftest
import
MockSubscriber
from
...utils
import
create_new_process_for_each_test
...
...
@@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
# Monkey-patch to extract core process pid while it's starting.
core_proc_pid
=
[
None
]
ce_ctor
=
CoreEngine
.
__init__
ce
pm
_ctor
=
CoreEngine
ProcManager
.
__init__
def
patched_ce_ctor
(
self
,
*
args
,
**
kwargs
):
ce_ctor
(
self
,
*
args
,
**
kwargs
)
core_proc_pid
[
0
]
=
self
.
proc
_handle
.
proc
.
pid
def
patched_ce
pm
_ctor
(
self
:
CoreEngineProcManager
,
*
args
,
**
kwargs
):
ce
pm
_ctor
(
self
,
*
args
,
**
kwargs
)
core_proc_pid
[
0
]
=
self
.
proc
esses
[
0
]
.
pid
m
.
setattr
(
CoreEngine
,
"__init__"
,
patched_ce_ctor
)
m
.
setattr
(
CoreEngine
ProcManager
,
"__init__"
,
patched_ce
pm
_ctor
)
t
=
time
.
time
()
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
)
...
...
vllm/config.py
View file @
55aa7af9
...
...
@@ -1668,25 +1668,17 @@ class ParallelConfig:
data_parallel_size
:
int
=
1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
data_parallel_size_local
:
int
=
1
"""Number of local data parallel groups."""
data_parallel_rank
:
int
=
0
"""Rank of the data parallel group."""
_data_parallel_rank_local
:
Optional
[
int
]
=
field
(
default
=
None
,
init
=
False
)
"""Private field to store the local rank of the data parallel group."""
@
property
def
data_parallel_rank_local
(
self
)
->
int
:
"""Local rank of the data parallel group, defaults to global rank."""
if
self
.
_data_parallel_rank_local
is
None
:
return
self
.
data_parallel_rank
return
self
.
_data_parallel_rank_local
@
data_parallel_rank_local
.
setter
def
data_parallel_rank_local
(
self
,
value
:
int
)
->
None
:
"""Set the local rank of the data parallel group."""
self
.
_data_parallel_rank_local
=
value
data_parallel_rank_local
:
Optional
[
int
]
=
None
"""Local rank of the data parallel group,
set only in SPMD mode."""
data_parallel_master_ip
:
str
=
"127.0.0.1"
"""IP of the data parallel master."""
data_parallel_rpc_port
:
int
=
29550
"""Port for data parallel messaging."""
data_parallel_master_port
:
int
=
29500
"""Port of the data parallel master."""
enable_expert_parallel
:
bool
=
False
...
...
@@ -1734,13 +1726,16 @@ class ParallelConfig:
world_size
:
int
=
field
(
init
=
False
)
"""world_size is TPxPP, it affects the number of workers we create."""
world_size_across_dp
:
int
=
field
(
init
=
False
)
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
rank
:
int
=
0
"""Global rank in distributed setup."""
@
property
def
world_size_across_dp
(
self
)
->
int
:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
return
self
.
world_size
*
self
.
data_parallel_size
def
get_next_dp_init_port
(
self
)
->
int
:
"""
We might need to initialize process groups in multiple
...
...
@@ -1800,10 +1795,14 @@ class ParallelConfig:
self
.
world_size
=
self
.
pipeline_parallel_size
*
\
self
.
tensor_parallel_size
if
self
.
data_parallel_size
>
1
:
if
self
.
data_parallel_size_local
>
self
.
data_parallel_size
:
raise
ValueError
(
f
"data_parallel_size_local (
{
self
.
data_parallel_size_local
}
) "
f
"must be <= data_parallel_size (
{
self
.
data_parallel_size
}
)"
)
if
self
.
data_parallel_size
>
1
or
self
.
data_parallel_size_local
==
0
:
# Data parallel was specified in the engine args.
self
.
data_parallel_master_port
=
get_open_port
()
# TODO multi-node
else
:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self
.
data_parallel_size
=
envs
.
VLLM_DP_SIZE
...
...
@@ -1812,8 +1811,6 @@ class ParallelConfig:
self
.
data_parallel_master_ip
=
envs
.
VLLM_DP_MASTER_IP
self
.
data_parallel_master_port
=
envs
.
VLLM_DP_MASTER_PORT
self
.
world_size_across_dp
=
self
.
world_size
*
self
.
data_parallel_size
if
self
.
distributed_executor_backend
==
"external_launcher"
:
import
os
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
...
...
vllm/distributed/utils.py
View file @
55aa7af9
...
...
@@ -22,6 +22,7 @@ from torch.distributed.rendezvous import rendezvous
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_tcp_uri
logger
=
init_logger
(
__name__
)
...
...
@@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group(
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
"""
init_method
=
f
"tcp://
{
host
}
:
{
port
}
"
init_method
=
get_tcp_uri
(
host
,
port
)
backend
=
Backend
(
backend
)
# it is basically string
timeout
=
_get_default_timeout
(
backend
)
...
...
vllm/engine/arg_utils.py
View file @
55aa7af9
...
...
@@ -283,6 +283,9 @@ class EngineArgs:
pipeline_parallel_size
:
int
=
ParallelConfig
.
pipeline_parallel_size
tensor_parallel_size
:
int
=
ParallelConfig
.
tensor_parallel_size
data_parallel_size
:
int
=
ParallelConfig
.
data_parallel_size
data_parallel_size_local
:
Optional
[
int
]
=
None
data_parallel_address
:
Optional
[
str
]
=
None
data_parallel_rpc_port
:
Optional
[
int
]
=
None
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
max_parallel_loading_workers
:
Optional
[
int
]
=
ParallelConfig
.
max_parallel_loading_workers
...
...
@@ -596,6 +599,21 @@ class EngineArgs:
**
parallel_kwargs
[
"tensor_parallel_size"
])
parallel_group
.
add_argument
(
"--data-parallel-size"
,
"-dp"
,
**
parallel_kwargs
[
"data_parallel_size"
])
parallel_group
.
add_argument
(
'--data-parallel-size-local'
,
'-dpl'
,
type
=
int
,
help
=
'Number of data parallel replicas '
'to run on this node.'
)
parallel_group
.
add_argument
(
'--data-parallel-address'
,
'-dpa'
,
type
=
str
,
help
=
'Address of data parallel cluster '
'head-node.'
)
parallel_group
.
add_argument
(
'--data-parallel-rpc-port'
,
'-dpp'
,
type
=
int
,
help
=
'Port for data parallel RPC '
'communication.'
)
parallel_group
.
add_argument
(
"--enable-expert-parallel"
,
**
parallel_kwargs
[
"enable_expert_parallel"
])
...
...
@@ -1019,10 +1037,30 @@ class EngineArgs:
# but we should not do this here.
placement_group
=
ray
.
util
.
get_current_placement_group
()
# Local DP size defaults to global DP size if not set.
data_parallel_size_local
=
self
.
data_parallel_size
if
(
self
.
data_parallel_size_local
is
None
)
else
self
.
data_parallel_size_local
# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
data_parallel_address
=
self
.
data_parallel_address
if
(
self
.
data_parallel_address
is
not
None
)
else
ParallelConfig
.
data_parallel_master_ip
# This port is only used when there are remote data parallel engines,
# otherwise the local IPC transport is used.
data_parallel_rpc_port
=
self
.
data_parallel_rpc_port
if
(
self
.
data_parallel_rpc_port
is
not
None
)
else
ParallelConfig
.
data_parallel_rpc_port
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
data_parallel_size
=
self
.
data_parallel_size
,
data_parallel_size_local
=
data_parallel_size_local
,
data_parallel_master_ip
=
data_parallel_address
,
data_parallel_rpc_port
=
data_parallel_rpc_port
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
...
...
vllm/entrypoints/cli/serve.py
View file @
55aa7af9
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
signal
import
uvloop
import
vllm.envs
as
envs
from
vllm
import
AsyncEngineArgs
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
get_tcp_uri
from
vllm.v1.engine.core
import
EngineCoreProc
from
vllm.v1.engine.core_client
import
CoreEngineProcManager
from
vllm.v1.executor.abstract
import
Executor
logger
=
init_logger
(
__name__
)
class
ServeSubcommand
(
CLISubcommand
):
...
...
@@ -24,7 +34,10 @@ class ServeSubcommand(CLISubcommand):
if
hasattr
(
args
,
'model_tag'
)
and
args
.
model_tag
is
not
None
:
args
.
model
=
args
.
model_tag
uvloop
.
run
(
run_server
(
args
))
if
args
.
headless
:
run_headless
(
args
)
else
:
uvloop
.
run
(
run_server
(
args
))
def
validate
(
self
,
args
:
argparse
.
Namespace
)
->
None
:
validate_parsed_serve_args
(
args
)
...
...
@@ -42,6 +55,18 @@ class ServeSubcommand(CLISubcommand):
nargs
=
'?'
,
help
=
"The model tag to serve "
"(optional if specified in config)"
)
serve_parser
.
add_argument
(
"--headless"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Run in headless mode. See multi-node data parallel "
"documentation for more details."
)
serve_parser
.
add_argument
(
'--data-parallel-start-rank'
,
'-dpr'
,
type
=
int
,
default
=
0
,
help
=
'Starting data parallel rank for secondary nodes.'
)
serve_parser
.
add_argument
(
"--config"
,
type
=
str
,
...
...
@@ -57,3 +82,55 @@ class ServeSubcommand(CLISubcommand):
def
cmd_init
()
->
list
[
CLISubcommand
]:
return
[
ServeSubcommand
()]
def
run_headless
(
args
:
argparse
.
Namespace
):
# Create the EngineConfig.
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
usage_context
=
UsageContext
.
OPENAI_API_SERVER
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
if
not
envs
.
VLLM_USE_V1
:
raise
RuntimeError
(
"Headless mode is only supported for V1"
)
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
host
=
parallel_config
.
data_parallel_master_ip
port
=
engine_args
.
data_parallel_rpc_port
# add to config too
input_address
=
get_tcp_uri
(
host
,
port
)
if
local_engine_count
<=
0
:
raise
RuntimeError
(
"data_parallel_size_local must be > 0 in "
"headless mode"
)
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def
signal_handler
(
signum
,
frame
):
logger
.
debug
(
"Received %d signal."
,
signum
)
raise
SystemExit
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
logger
.
info
(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s."
,
local_engine_count
,
input_address
)
# Create the engines.
engine_manager
=
CoreEngineProcManager
(
target_fn
=
EngineCoreProc
.
run_engine_core
,
local_engine_count
=
local_engine_count
,
start_index
=
args
.
data_parallel_start_rank
,
local_start_index
=
0
,
vllm_config
=
vllm_config
,
on_head_node
=
False
,
input_address
=
input_address
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
)
try
:
engine_manager
.
join_first
()
finally
:
logger
.
info
(
"Shutting down."
)
engine_manager
.
close
()
vllm/utils.py
View file @
55aa7af9
...
...
@@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool:
def
get_distributed_init_method
(
ip
:
str
,
port
:
int
)
->
str
:
return
get_tcp_uri
(
ip
,
port
)
def
get_tcp_uri
(
ip
:
str
,
port
:
int
)
->
str
:
# Brackets are not permitted in ipv4 addresses,
# see https://github.com/python/cpython/issues/103848
return
f
"tcp://[
{
ip
}
]:
{
port
}
"
if
":"
in
ip
else
f
"tcp://
{
ip
}
:
{
port
}
"
...
...
vllm/v1/engine/core.py
View file @
55aa7af9
# SPDX-License-Identifier: Apache-2.0
import
json
import
os
import
queue
import
signal
...
...
@@ -23,7 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
resolve_obj_by_qualname
,
zmq_socket_ctx
from
vllm.utils
import
make_zmq_socket
,
resolve_obj_by_qualname
,
zmq_socket_ctx
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
unify_kv_cache_configs
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
...
...
@@ -43,6 +42,7 @@ from vllm.version import __version__ as VLLM_VERSION
logger
=
init_logger
(
__name__
)
POLLING_TIMEOUT_S
=
2.5
HANDSHAKE_TIMEOUT_MINS
=
5
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
...
...
@@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore):
def
__init__
(
self
,
input_path
:
str
,
output_path
:
str
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
engine_index
:
int
=
0
,
...
...
@@ -360,28 +360,91 @@ class EngineCoreProc(EngineCore):
executor_fail_callback
=
lambda
:
input_queue
.
put_nowait
(
(
EngineCoreRequestType
.
EXECUTOR_FAILED
,
b
''
))
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
executor_fail_callback
)
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
self
.
engines_running
=
False
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
input_queue
self
.
output_queue
=
queue
.
Queue
[
Union
[
EngineCoreOutputs
,
bytes
]]()
threading
.
Thread
(
target
=
self
.
process_input_socket
,
args
=
(
input_path
,
engine_index
),
daemon
=
True
).
start
()
self
.
output_thread
=
threading
.
Thread
(
target
=
self
.
process_output_socket
,
args
=
(
output_path
,
engine_index
),
daemon
=
True
)
self
.
output_thread
.
start
()
# Create input socket.
input_ctx
=
zmq
.
Context
()
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
input_socket
=
make_zmq_socket
(
input_ctx
,
input_address
,
zmq
.
DEALER
,
identity
=
identity
,
bind
=
False
)
try
:
# Register engine with front-end.
output_address
=
self
.
startup_handshake
(
input_socket
,
on_head_node
,
vllm_config
.
parallel_config
)
# Update config which may have changed from the handshake.
vllm_config
.
__post_init__
()
# Set up data parallel environment.
self
.
_init_data_parallel
(
vllm_config
)
# Initialize engine core and model.
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
executor_fail_callback
)
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
self
.
engines_running
=
False
# Send ready message.
num_gpu_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
input_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"READY"
,
"local"
:
on_head_node
,
"num_gpu_blocks"
:
num_gpu_blocks
,
}))
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
input_queue
self
.
output_queue
=
queue
.
Queue
[
Union
[
EngineCoreOutputs
,
bytes
]]()
threading
.
Thread
(
target
=
self
.
process_input_socket
,
args
=
(
input_socket
,
),
daemon
=
True
).
start
()
input_socket
=
None
self
.
output_thread
=
threading
.
Thread
(
target
=
self
.
process_output_socket
,
args
=
(
output_address
,
engine_index
),
daemon
=
True
)
self
.
output_thread
.
start
()
finally
:
if
input_socket
is
not
None
:
input_socket
.
close
(
linger
=
0
)
@
staticmethod
def
startup_handshake
(
input_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
parallel_config
:
ParallelConfig
)
->
str
:
# Send registration message.
input_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"HELLO"
,
"local"
:
on_head_node
,
}))
# Receive initialization message.
logger
.
info
(
"Waiting for init message from front-end."
)
if
not
input_socket
.
poll
(
timeout
=
HANDSHAKE_TIMEOUT_MINS
*
60
*
1000
):
raise
RuntimeError
(
"Did not receive response from front-end "
f
"process within
{
HANDSHAKE_TIMEOUT_MINS
}
"
f
"minutes"
)
init_bytes
=
input_socket
.
recv
()
init_message
=
msgspec
.
msgpack
.
decode
(
init_bytes
)
logger
.
debug
(
"Received init message: %s"
,
init_message
)
output_socket_address
=
init_message
[
"output_socket_address"
]
#TBD(nick) maybe replace IP with configured head node address
received_parallel_config
=
init_message
[
"parallel_config"
]
for
key
,
value
in
received_parallel_config
.
items
():
setattr
(
parallel_config
,
key
,
value
)
return
output_socket_address
@
staticmethod
def
run_engine_core
(
*
args
,
...
...
@@ -412,7 +475,7 @@ class EngineCoreProc(EngineCore):
try
:
parallel_config
:
ParallelConfig
=
kwargs
[
"vllm_config"
].
parallel_config
if
parallel_config
.
data_parallel_size
>
1
:
if
parallel_config
.
data_parallel_size
>
1
or
dp_rank
>
0
:
# Set data parallel rank for this engine process.
parallel_config
.
data_parallel_rank
=
dp_rank
parallel_config
.
data_parallel_rank_local
=
local_dp_rank
...
...
@@ -436,6 +499,9 @@ class EngineCoreProc(EngineCore):
if
engine_core
is
not
None
:
engine_core
.
shutdown
()
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
pass
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore."""
...
...
@@ -527,40 +593,25 @@ class EngineCoreProc(EngineCore):
logger
.
fatal
(
"vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue."
)
def
process_input_socket
(
self
,
input_
path
:
str
,
engine_index
:
in
t
):
def
process_input_socket
(
self
,
input_
socket
:
zmq
.
Socke
t
):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
generic_decoder
=
MsgpackDecoder
()
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
with
zmq_socket_ctx
(
input_path
,
zmq
.
DEALER
,
identity
=
identity
,
bind
=
False
)
as
socket
:
# Send ready message to front-end once input socket is connected.
message_dict
=
{
'type'
:
'READY'
,
'num_gpu_blocks'
:
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
,
}
message
=
json
.
dumps
(
message_dict
).
encode
(
'utf-8'
)
socket
.
send
(
message
)
while
True
:
# (RequestType, RequestData)
type_frame
,
*
data_frames
=
socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
while
True
:
# (RequestType, RequestData)
type_frame
,
*
data_frames
=
input_socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frames
)
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frames
)
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
def
process_output_socket
(
self
,
output_path
:
str
,
engine_index
:
int
):
"""Output socket IO thread."""
...
...
@@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc):
def
__init__
(
self
,
input_path
:
str
,
output_path
:
str
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
):
...
...
@@ -623,8 +674,20 @@ class DPEngineCoreProc(EngineCoreProc):
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self
.
counter
=
0
# Initialize the engine.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
super
().
__init__
(
vllm_config
,
on_head_node
,
input_address
,
executor_class
,
log_stats
,
dp_rank
)
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
# Configure GPUs and stateless process group for data parallel.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
local_dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank_local
assert
dp_size
>
1
...
...
@@ -632,24 +695,16 @@ class DPEngineCoreProc(EngineCoreProc):
from
vllm.platforms
import
current_platform
device_control_env_var
=
current_platform
.
device_control_env_var
tp
_size
=
vllm_config
.
parallel_config
.
tensor_parallel
_size
world
_size
=
vllm_config
.
parallel_config
.
world
_size
os
.
environ
[
device_control_env_var
]
=
","
.
join
(
str
(
current_platform
.
device_id_to_physical_device_id
(
i
))
for
i
in
range
(
local_dp_rank
*
tp
_size
,
(
local_dp_rank
+
1
)
*
tp
_size
))
for
i
in
range
(
local_dp_rank
*
world
_size
,
(
local_dp_rank
+
1
)
*
world
_size
))
self
.
local_dp_rank
=
local_dp_rank
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
current_wave
=
0
# Initialize the engine after setting up environment.
super
().
__init__
(
input_path
,
output_path
,
vllm_config
,
executor_class
,
log_stats
,
dp_rank
)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self
.
counter
=
0
def
shutdown
(
self
):
super
().
shutdown
()
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
...
...
vllm/v1/engine/core_client.py
View file @
55aa7af9
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
contextlib
import
json
import
queue
import
uuid
import
weakref
...
...
@@ -9,25 +8,27 @@ from abc import ABC, abstractmethod
from
collections
import
deque
from
collections.abc
import
Awaitable
,
Sequence
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
threading
import
Thread
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
msgspec
import
zmq
import
zmq.asyncio
from
vllm.config
import
VllmConfig
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.utils
import
(
get_open_
zmq_inproc_path
,
get_open_zmq_i
p
c_path
,
make_zmq_socket
)
from
vllm.utils
import
(
get_open_
port
,
get_open_zmq_i
npro
c_path
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
make_zmq_socket
)
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
Background
Proc
H
an
dle
from
vllm.v1.utils
import
CoreEngine
Proc
M
an
ager
logger
=
init_logger
(
__name__
)
...
...
@@ -264,45 +265,22 @@ class InprocClient(EngineCoreClient):
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
CoreEngineState
(
Enum
):
NEW
=
auto
()
CONNECTED
=
auto
()
READY
=
auto
()
class
CoreEngine
:
"""One per data parallel rank."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
input_path
:
str
,
output_path
:
str
,
index
:
int
=
0
,
local_dp_rank
:
int
=
0
,
):
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
self
.
local
=
local
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
try
:
# Start EngineCore in background process.
self
.
proc_handle
=
BackgroundProcHandle
(
input_path
=
input_path
,
output_path
=
output_path
,
process_name
=
f
"EngineCore_
{
index
}
"
,
target_fn
=
EngineCoreProc
.
run_engine_core
,
process_kwargs
=
{
"vllm_config"
:
vllm_config
,
"dp_rank"
:
index
,
"local_dp_rank"
:
local_dp_rank
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
})
self
.
num_reqs_in_flight
=
0
finally
:
if
not
hasattr
(
self
,
"num_reqs_in_flight"
):
# Ensure socket is closed if process fails to start.
self
.
close
()
def
close
(
self
):
if
proc_handle
:
=
getattr
(
self
,
"proc_handle"
,
None
):
proc_handle
.
shutdown
()
self
.
state
=
CoreEngineState
.
NEW
self
.
num_reqs_in_flight
=
0
@
dataclass
...
...
@@ -311,7 +289,7 @@ class BackgroundResources:
circular reference back to the client object."""
ctx
:
Union
[
zmq
.
Context
]
core
_engine
s
:
list
[
CoreEngine
]
=
field
(
default_factory
=
list
)
local
_engine
_manager
:
Optional
[
CoreEngineProcManager
]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
output_queue_task
:
Optional
[
asyncio
.
Task
]
=
None
...
...
@@ -325,8 +303,8 @@ class BackgroundResources:
"""Clean up background resources."""
self
.
engine_dead
=
True
for
core_engine
in
self
.
core_engi
ne
s
:
core_engine
.
close
()
if
self
.
local_engine_manager
is
not
No
ne
:
self
.
local_engine_manager
.
close
()
if
self
.
output_queue_task
is
not
None
:
self
.
output_queue_task
.
cancel
()
...
...
@@ -388,25 +366,56 @@ class MPClient(EngineCoreClient):
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
resources
)
success
=
False
try
:
# Paths and sockets for IPC.
self
.
output_path
=
get_open_zmq_ipc_path
()
input_path
=
get_open_zmq_ipc_path
()
self
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_path
,
zmq
.
ROUTER
,
bind
=
True
)
self
.
resources
.
input_socket
=
self
.
input_socket
new_core_engine
=
lambda
index
,
local_dp_rank
=
None
:
CoreEngine
(
vllm_config
,
executor_class
,
log_stats
,
input_path
,
self
.
output_path
,
index
,
local_dp_rank
)
# Start engine core process(es).
self
.
_init_core_engines
(
vllm_config
,
new_core_engine
,
self
.
resources
.
core_engines
)
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
start_index
=
parallel_config
.
data_parallel_rank
local_start_index
=
parallel_config
.
data_parallel_rank_local
# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
spmd_mode
=
local_start_index
is
not
None
if
spmd_mode
:
assert
local_engine_count
==
1
self
.
core_engines
=
[
CoreEngine
(
index
=
local_start_index
,
local
=
True
)
]
else
:
assert
start_index
==
0
local_start_index
=
0
self
.
core_engines
=
[
CoreEngine
(
index
=
i
,
local
=
(
i
<
local_engine_count
))
for
i
in
range
(
parallel_config
.
data_parallel_size
)
]
input_address
,
output_address
=
self
.
_get_zmq_addresses
(
parallel_config
,
spmd_mode
)
# Create input and output sockets.
self
.
input_socket
=
self
.
resources
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_address
,
zmq
.
ROUTER
,
bind
=
True
)
self
.
resources
.
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_address
,
zmq
.
constants
.
PULL
)
# Start local engines.
if
local_engine_count
:
# In server mode, start_index and local_start_index will
# both be 0.
self
.
resources
.
local_engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
input_address
=
input_address
,
on_head_node
=
True
,
local_engine_count
=
local_engine_count
,
start_index
=
start_index
,
local_start_index
=
local_start_index
)
self
.
core_engine
=
self
.
core_engines
[
0
]
# Wait for engine core process(es) to start.
self
.
_wait_for_engine_startup
()
self
.
_wait_for_engine_startup
(
output_address
,
parallel_config
)
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
...
...
@@ -420,56 +429,116 @@ class MPClient(EngineCoreClient):
if
not
success
:
self
.
_finalizer
()
def
_wait_for_engine_startup
(
self
):
@
staticmethod
def
_get_zmq_addresses
(
parallel_config
:
ParallelConfig
,
spmd_mode
:
bool
)
->
tuple
[
str
,
str
]:
"""Returns (input_address, output_address)."""
dp_size
=
parallel_config
.
data_parallel_size
local_engine_count
=
parallel_config
.
data_parallel_size_local
if
local_engine_count
==
dp_size
or
spmd_mode
:
input_address
=
get_open_zmq_ipc_path
()
output_address
=
get_open_zmq_ipc_path
()
else
:
host
=
parallel_config
.
data_parallel_master_ip
input_port
=
parallel_config
.
data_parallel_rpc_port
output_port
=
get_open_port
()
input_address
=
get_tcp_uri
(
host
,
input_port
)
output_address
=
get_tcp_uri
(
host
,
output_port
)
return
input_address
,
output_address
def
_wait_for_engine_startup
(
self
,
output_address
:
str
,
parallel_config
:
ParallelConfig
):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
# Wait for engine core process(es) to send ready messages.
identities
=
set
(
eng
.
index
for
eng
in
self
.
resources
.
core_engines
)
local_count
=
parallel_config
.
data_parallel_size_local
remote_count
=
len
(
self
.
core_engines
)
-
local_count
# [local, remote] counts
conn_pending
,
start_pending
=
[
local_count
,
remote_count
],
[
0
,
0
]
poller
=
zmq
.
Poller
()
poller
.
register
(
sync_input_socket
,
zmq
.
POLLIN
)
for
eng
in
self
.
resources
.
core_engines
:
poller
.
register
(
eng
.
proc_handle
,
zmq
.
POLLIN
)
while
identities
:
proc_manager
=
self
.
resources
.
local_engine_manager
if
proc_manager
is
not
None
:
for
sentinel
in
proc_manager
.
sentinels
():
poller
.
register
(
sentinel
,
zmq
.
POLLIN
)
while
any
(
conn_pending
)
or
any
(
start_pending
):
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
if
not
events
:
logger
.
debug
(
"Waiting for %d core engine proc(s) to start: %s"
,
len
(
identities
),
identities
)
if
any
(
conn_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect."
,
*
conn_pending
)
if
any
(
start_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to start."
,
*
start_pending
)
continue
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
sync_input_socket
:
# One of the core processes exited.
# One of the local core processes exited.
finished
=
proc_manager
.
finished_procs
(
)
if
proc_manager
else
{}
raise
RuntimeError
(
"Engine core initialization failed. "
"See root cause above."
)
eng_id_bytes
,
data
=
sync_input_socket
.
recv_multipart
()
eng_id
=
int
.
from_bytes
(
eng_id_bytes
,
byteorder
=
"little"
)
if
eng_id
not
in
identities
:
raise
RuntimeError
(
f
"Unexpected or duplicate engine:
{
eng_id
}
"
)
message_dict
=
json
.
loads
(
data
.
decode
(
'utf-8'
))
if
message_dict
[
'type'
]
!=
'READY'
:
raise
RuntimeError
(
f
"Engine
{
eng_id
}
failed:
{
data
.
decode
()
}
"
)
logger
.
info
(
"Core engine process %d ready."
,
eng_id
)
identities
.
discard
(
eng_id
)
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks
=
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
or
0
num_gpu_blocks
+=
message_dict
[
'num_gpu_blocks'
]
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
new_core_engine
:
Callable
[[
int
,
Optional
[
int
]],
CoreEngine
],
core_engines
:
list
[
CoreEngine
],
)
->
None
:
# Default case - single core engine.
core_engine
=
new_core_engine
(
vllm_config
.
parallel_config
.
data_parallel_rank
,
vllm_config
.
parallel_config
.
data_parallel_rank_local
,
)
core_engines
.
append
(
core_engine
)
self
.
core_engine
=
core_engine
"See root cause above. "
f
"Failed core proc(s):
{
finished
}
"
)
# Receive HELLO and READY messages from the input socket.
eng_identity
,
ready_msg_bytes
=
sync_input_socket
.
recv_multipart
()
eng_index
=
int
.
from_bytes
(
eng_identity
,
byteorder
=
"little"
)
engine
=
next
(
(
e
for
e
in
self
.
core_engines
if
e
.
identity
==
eng_identity
),
None
)
if
engine
is
None
:
raise
RuntimeError
(
f
"Message from engine with unexpected data "
f
"parallel rank:
{
eng_index
}
"
)
msg
=
msgspec
.
msgpack
.
decode
(
ready_msg_bytes
)
status
,
local
=
msg
[
"status"
],
msg
[
"local"
]
if
local
!=
engine
.
local
:
raise
RuntimeError
(
f
"
{
status
}
message from "
f
"
{
'local'
if
local
else
'remote'
}
"
f
"engine
{
eng_index
}
, expected it to be "
f
"
{
'local'
if
engine
.
local
else
'remote'
}
"
)
if
status
==
"HELLO"
and
engine
.
state
==
CoreEngineState
.
NEW
:
# Send init message with DP config info.
init_message
=
self
.
encoder
.
encode
({
"output_socket_address"
:
output_address
,
"parallel_config"
:
{
"data_parallel_master_ip"
:
parallel_config
.
data_parallel_master_ip
,
"data_parallel_master_port"
:
parallel_config
.
data_parallel_master_port
,
"data_parallel_size"
:
parallel_config
.
data_parallel_size
,
},
})
sync_input_socket
.
send_multipart
((
eng_identity
,
*
init_message
),
copy
=
False
)
conn_pending
[
0
if
local
else
1
]
-=
1
start_pending
[
0
if
local
else
1
]
+=
1
engine
.
state
=
CoreEngineState
.
CONNECTED
elif
status
==
"READY"
and
(
engine
.
state
==
CoreEngineState
.
CONNECTED
):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
cache_config
=
self
.
vllm_config
.
cache_config
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
or
0
num_gpu_blocks
+=
msg
[
'num_gpu_blocks'
]
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
start_pending
[
0
if
local
else
1
]
-=
1
engine
.
state
=
CoreEngineState
.
READY
else
:
raise
RuntimeError
(
f
"Unexpected
{
status
}
message for "
f
"
{
'local'
if
local
else
'remote'
}
engine "
f
"
{
eng_index
}
in
{
engine
.
state
}
state."
)
logger
.
debug
(
"%s from %s core engine process %s."
,
status
,
"local"
if
local
else
"remote"
,
eng_index
)
def
shutdown
(
self
):
# Terminate background resources.
...
...
@@ -520,7 +589,8 @@ class SyncMPClient(MPClient):
# Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc.
ctx
=
self
.
ctx
output_path
=
self
.
output_path
out_socket
=
self
.
resources
.
output_socket
assert
out_socket
is
not
None
decoder
=
self
.
decoder
utility_results
=
self
.
utility_results
outputs_queue
=
self
.
outputs_queue
...
...
@@ -531,7 +601,6 @@ class SyncMPClient(MPClient):
def
process_outputs_socket
():
shutdown_socket
=
ctx
.
socket
(
zmq
.
PAIR
)
out_socket
=
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
try
:
shutdown_socket
.
bind
(
shutdown_path
)
poller
=
zmq
.
Poller
()
...
...
@@ -566,6 +635,9 @@ class SyncMPClient(MPClient):
daemon
=
True
)
self
.
output_queue_thread
.
start
()
# The thread takes on responsibility for closing the socket.
self
.
resources
.
output_socket
=
None
def
get_output
(
self
)
->
EngineCoreOutputs
:
# If an exception arises in process_outputs_socket task,
# it is forwarded to the outputs_queue so we can raise it
...
...
@@ -693,10 +765,8 @@ class AsyncMPClient(MPClient):
self
.
__class__
,
"process_engine_outputs"
,
None
)
_self_ref
=
weakref
.
ref
(
self
)
if
output_handler
else
None
output_path
=
self
.
output_path
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
resources
.
output_socket
=
output_socket
output_socket
=
resources
.
output_socket
assert
output_socket
is
not
None
async
def
process_outputs_socket
():
try
:
...
...
@@ -861,21 +931,6 @@ class DPAsyncMPClient(AsyncMPClient):
assert
len
(
self
.
core_engines
)
>
1
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
new_core_engine
:
Callable
[[
int
,
Optional
[
int
]],
CoreEngine
],
core_engines
:
list
[
CoreEngine
],
)
->
None
:
# Launch a core engine for each data parallel rank.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
for
i
in
range
(
dp_size
):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines
.
append
(
new_core_engine
(
i
,
i
))
self
.
core_engines
=
core_engines
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
# Only the result from the first engine is returned.
return
(
await
asyncio
.
gather
(
*
[
...
...
vllm/v1/utils.py
View file @
55aa7af9
# SPDX-License-Identifier: Apache-2.0
import
os
import
time
import
weakref
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
multiprocessing
import
Process
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
from
multiprocessing
import
Process
,
connection
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
import
torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
get_mp_context
,
kill_process_tree
from
vllm.v1.executor.abstract
import
Executor
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
...
...
@@ -92,7 +95,7 @@ class ConstantList(Generic[T], Sequence):
return
f
"ConstantList(
{
self
.
_x
}
)"
class
Background
Proc
H
an
dle
:
class
CoreEngine
Proc
M
an
ager
:
"""
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
...
...
@@ -100,49 +103,91 @@ class BackgroundProcHandle:
def
__init__
(
self
,
input_path
:
str
,
output_path
:
str
,
process_name
:
str
,
target_fn
:
Callable
,
process_kwargs
:
dict
[
Any
,
Any
],
local_engine_count
:
int
,
start_index
:
int
,
local_start_index
:
int
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
):
context
=
get_mp_context
()
common_kwargs
=
{
"vllm_config"
:
vllm_config
,
"on_head_node"
:
on_head_node
,
"input_address"
:
input_address
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
}
self
.
processes
:
list
[
Process
]
=
[]
for
index
in
range
(
local_engine_count
):
local_index
=
local_start_index
+
index
global_index
=
start_index
+
index
# Start EngineCore in background process.
self
.
processes
.
append
(
context
.
Process
(
target
=
target_fn
,
name
=
f
"EngineCore_
{
global_index
}
"
,
kwargs
=
common_kwargs
|
{
"dp_rank"
:
global_index
,
"local_dp_rank"
:
local_index
,
}))
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
,
input_address
)
try
:
for
proc
in
self
.
processes
:
proc
.
start
()
finally
:
# Kill other procs if not all are running.
if
self
.
finished_procs
():
self
.
close
()
def
close
(
self
):
"""Shutdown all procs."""
self
.
_finalizer
()
assert
(
"input_path"
not
in
process_kwargs
and
"output_path"
not
in
process_kwargs
)
process_kwargs
[
"input_path"
]
=
input_path
process_kwargs
[
"output_path"
]
=
output_path
# Run busy loop in background process.
self
.
proc
:
Process
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
,
name
=
process_name
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
proc
,
input_path
,
output_path
)
self
.
proc
.
start
()
def
join_first
(
self
):
"""Wait for any process to exit."""
connection
.
wait
(
proc
.
sentinel
for
proc
in
self
.
processes
)
def
fileno
(
self
)
:
return
self
.
proc
.
sentinel
def
sentinels
(
self
)
->
list
:
return
[
proc
.
sentinel
for
proc
in
self
.
processes
]
def
shutdown
(
self
):
self
.
_finalizer
()
def
finished_procs
(
self
)
->
dict
[
str
,
int
]:
"""Returns dict of proc name -> exit code for any finished procs."""
return
{
proc
.
name
:
proc
.
exitcode
for
proc
in
self
.
processes
if
proc
.
exitcode
is
not
None
}
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def
shutdown
(
proc
:
Process
,
input_
path
:
str
,
output_path
:
str
):
# else the gc cannot collect the obje
decoup
ct.
def
shutdown
(
proc
s
:
list
[
Process
]
,
input_
address
:
str
):
# Shutdown the process.
if
proc
.
is_alive
():
proc
.
terminate
()
proc
.
join
(
5
)
for
proc
in
procs
:
if
proc
.
is_alive
():
proc
.
terminate
()
# Allow 5 seconds for remaining procs to terminate.
deadline
=
time
.
monotonic
()
+
5
for
proc
in
procs
:
remaining
=
deadline
-
time
.
monotonic
()
if
remaining
<=
0
:
break
if
proc
.
is_alive
():
proc
.
join
(
remaining
)
for
proc
in
procs
:
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
kill_process_tree
(
pid
)
# Remove zmq ipc socket files.
ipc_sockets
=
[
output_path
,
input_path
]
for
ipc_socket
in
ipc_sockets
:
socket_file
=
ipc_socket
.
replace
(
"ipc://"
,
""
)
if
input_address
.
startswith
(
"ipc://"
):
socket_file
=
input_address
[
len
(
"ipc://"
):]
if
os
and
os
.
path
.
exists
(
socket_file
):
os
.
remove
(
socket_file
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment