Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
55aa7af9
Unverified
Commit
55aa7af9
authored
May 13, 2025
by
Nick Hill
Committed by
GitHub
May 13, 2025
Browse files
[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
0b217da6
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
525 additions
and
252 deletions
+525
-252
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+1
-1
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+8
-7
vllm/config.py
vllm/config.py
+19
-22
vllm/distributed/utils.py
vllm/distributed/utils.py
+2
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+38
-0
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+79
-2
vllm/utils.py
vllm/utils.py
+4
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+122
-67
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+174
-119
vllm/v1/utils.py
vllm/v1/utils.py
+78
-33
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
55aa7af9
...
@@ -41,7 +41,7 @@ class MockEngine:
...
@@ -41,7 +41,7 @@ class MockEngine:
self
.
abort_request_calls
=
0
self
.
abort_request_calls
=
0
self
.
request_id
=
None
self
.
request_id
=
None
# Ugly, remove dependency when possible
# Ugly, remove dependency when possible
self
.
parallel_config
=
ParallelConfig
(
1
,
1
,
False
)
self
.
parallel_config
=
ParallelConfig
()
self
.
model_config
=
MockModelConfig
()
self
.
model_config
=
MockModelConfig
()
async
def
step_async
(
self
,
virtual_engine
):
async
def
step_async
(
self
,
virtual_engine
):
...
...
tests/v1/engine/test_engine_core_client.py
View file @
55aa7af9
...
@@ -18,9 +18,10 @@ from vllm.platforms import current_platform
...
@@ -18,9 +18,10 @@ from vllm.platforms import current_platform
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core
import
EngineCore
from
vllm.v1.engine.core
import
EngineCore
from
vllm.v1.engine.core_client
import
(
AsyncMPClient
,
Core
Engine
,
from
vllm.v1.engine.core_client
import
(
AsyncMPClient
,
Engine
CoreClient
,
EngineCoreClient
,
SyncMPClient
)
SyncMPClient
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.utils
import
CoreEngineProcManager
from
...distributed.conftest
import
MockSubscriber
from
...distributed.conftest
import
MockSubscriber
from
...utils
import
create_new_process_for_each_test
from
...utils
import
create_new_process_for_each_test
...
@@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
...
@@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
# Monkey-patch to extract core process pid while it's starting.
# Monkey-patch to extract core process pid while it's starting.
core_proc_pid
=
[
None
]
core_proc_pid
=
[
None
]
ce_ctor
=
CoreEngine
.
__init__
ce
pm
_ctor
=
CoreEngine
ProcManager
.
__init__
def
patched_ce_ctor
(
self
,
*
args
,
**
kwargs
):
def
patched_ce
pm
_ctor
(
self
:
CoreEngineProcManager
,
*
args
,
**
kwargs
):
ce_ctor
(
self
,
*
args
,
**
kwargs
)
ce
pm
_ctor
(
self
,
*
args
,
**
kwargs
)
core_proc_pid
[
0
]
=
self
.
proc
_handle
.
proc
.
pid
core_proc_pid
[
0
]
=
self
.
proc
esses
[
0
]
.
pid
m
.
setattr
(
CoreEngine
,
"__init__"
,
patched_ce_ctor
)
m
.
setattr
(
CoreEngine
ProcManager
,
"__init__"
,
patched_ce
pm
_ctor
)
t
=
time
.
time
()
t
=
time
.
time
()
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
)
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
)
...
...
vllm/config.py
View file @
55aa7af9
...
@@ -1668,25 +1668,17 @@ class ParallelConfig:
...
@@ -1668,25 +1668,17 @@ class ParallelConfig:
data_parallel_size
:
int
=
1
data_parallel_size
:
int
=
1
"""Number of data parallel groups. MoE layers will be sharded according to
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
the product of the tensor parallel size and data parallel size."""
data_parallel_size_local
:
int
=
1
"""Number of local data parallel groups."""
data_parallel_rank
:
int
=
0
data_parallel_rank
:
int
=
0
"""Rank of the data parallel group."""
"""Rank of the data parallel group."""
_data_parallel_rank_local
:
Optional
[
int
]
=
field
(
default
=
None
,
init
=
False
)
data_parallel_rank_local
:
Optional
[
int
]
=
None
"""Private field to store the local rank of the data parallel group."""
"""Local rank of the data parallel group,
set only in SPMD mode."""
@
property
def
data_parallel_rank_local
(
self
)
->
int
:
"""Local rank of the data parallel group, defaults to global rank."""
if
self
.
_data_parallel_rank_local
is
None
:
return
self
.
data_parallel_rank
return
self
.
_data_parallel_rank_local
@
data_parallel_rank_local
.
setter
def
data_parallel_rank_local
(
self
,
value
:
int
)
->
None
:
"""Set the local rank of the data parallel group."""
self
.
_data_parallel_rank_local
=
value
data_parallel_master_ip
:
str
=
"127.0.0.1"
data_parallel_master_ip
:
str
=
"127.0.0.1"
"""IP of the data parallel master."""
"""IP of the data parallel master."""
data_parallel_rpc_port
:
int
=
29550
"""Port for data parallel messaging."""
data_parallel_master_port
:
int
=
29500
data_parallel_master_port
:
int
=
29500
"""Port of the data parallel master."""
"""Port of the data parallel master."""
enable_expert_parallel
:
bool
=
False
enable_expert_parallel
:
bool
=
False
...
@@ -1734,13 +1726,16 @@ class ParallelConfig:
...
@@ -1734,13 +1726,16 @@ class ParallelConfig:
world_size
:
int
=
field
(
init
=
False
)
world_size
:
int
=
field
(
init
=
False
)
"""world_size is TPxPP, it affects the number of workers we create."""
"""world_size is TPxPP, it affects the number of workers we create."""
world_size_across_dp
:
int
=
field
(
init
=
False
)
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
rank
:
int
=
0
rank
:
int
=
0
"""Global rank in distributed setup."""
"""Global rank in distributed setup."""
@
property
def
world_size_across_dp
(
self
)
->
int
:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
return
self
.
world_size
*
self
.
data_parallel_size
def
get_next_dp_init_port
(
self
)
->
int
:
def
get_next_dp_init_port
(
self
)
->
int
:
"""
"""
We might need to initialize process groups in multiple
We might need to initialize process groups in multiple
...
@@ -1800,10 +1795,14 @@ class ParallelConfig:
...
@@ -1800,10 +1795,14 @@ class ParallelConfig:
self
.
world_size
=
self
.
pipeline_parallel_size
*
\
self
.
world_size
=
self
.
pipeline_parallel_size
*
\
self
.
tensor_parallel_size
self
.
tensor_parallel_size
if
self
.
data_parallel_size
>
1
:
if
self
.
data_parallel_size_local
>
self
.
data_parallel_size
:
raise
ValueError
(
f
"data_parallel_size_local (
{
self
.
data_parallel_size_local
}
) "
f
"must be <= data_parallel_size (
{
self
.
data_parallel_size
}
)"
)
if
self
.
data_parallel_size
>
1
or
self
.
data_parallel_size_local
==
0
:
# Data parallel was specified in the engine args.
# Data parallel was specified in the engine args.
self
.
data_parallel_master_port
=
get_open_port
()
self
.
data_parallel_master_port
=
get_open_port
()
# TODO multi-node
else
:
else
:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self
.
data_parallel_size
=
envs
.
VLLM_DP_SIZE
self
.
data_parallel_size
=
envs
.
VLLM_DP_SIZE
...
@@ -1812,8 +1811,6 @@ class ParallelConfig:
...
@@ -1812,8 +1811,6 @@ class ParallelConfig:
self
.
data_parallel_master_ip
=
envs
.
VLLM_DP_MASTER_IP
self
.
data_parallel_master_ip
=
envs
.
VLLM_DP_MASTER_IP
self
.
data_parallel_master_port
=
envs
.
VLLM_DP_MASTER_PORT
self
.
data_parallel_master_port
=
envs
.
VLLM_DP_MASTER_PORT
self
.
world_size_across_dp
=
self
.
world_size
*
self
.
data_parallel_size
if
self
.
distributed_executor_backend
==
"external_launcher"
:
if
self
.
distributed_executor_backend
==
"external_launcher"
:
import
os
import
os
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
...
...
vllm/distributed/utils.py
View file @
55aa7af9
...
@@ -22,6 +22,7 @@ from torch.distributed.rendezvous import rendezvous
...
@@ -22,6 +22,7 @@ from torch.distributed.rendezvous import rendezvous
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_tcp_uri
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group(
...
@@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group(
always formed with process 1, 2, ..., 8, and the additional communication
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
channel is formed with process 9 and 10.
"""
"""
init_method
=
f
"tcp://
{
host
}
:
{
port
}
"
init_method
=
get_tcp_uri
(
host
,
port
)
backend
=
Backend
(
backend
)
# it is basically string
backend
=
Backend
(
backend
)
# it is basically string
timeout
=
_get_default_timeout
(
backend
)
timeout
=
_get_default_timeout
(
backend
)
...
...
vllm/engine/arg_utils.py
View file @
55aa7af9
...
@@ -283,6 +283,9 @@ class EngineArgs:
...
@@ -283,6 +283,9 @@ class EngineArgs:
pipeline_parallel_size
:
int
=
ParallelConfig
.
pipeline_parallel_size
pipeline_parallel_size
:
int
=
ParallelConfig
.
pipeline_parallel_size
tensor_parallel_size
:
int
=
ParallelConfig
.
tensor_parallel_size
tensor_parallel_size
:
int
=
ParallelConfig
.
tensor_parallel_size
data_parallel_size
:
int
=
ParallelConfig
.
data_parallel_size
data_parallel_size
:
int
=
ParallelConfig
.
data_parallel_size
data_parallel_size_local
:
Optional
[
int
]
=
None
data_parallel_address
:
Optional
[
str
]
=
None
data_parallel_rpc_port
:
Optional
[
int
]
=
None
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
max_parallel_loading_workers
:
Optional
[
max_parallel_loading_workers
:
Optional
[
int
]
=
ParallelConfig
.
max_parallel_loading_workers
int
]
=
ParallelConfig
.
max_parallel_loading_workers
...
@@ -596,6 +599,21 @@ class EngineArgs:
...
@@ -596,6 +599,21 @@ class EngineArgs:
**
parallel_kwargs
[
"tensor_parallel_size"
])
**
parallel_kwargs
[
"tensor_parallel_size"
])
parallel_group
.
add_argument
(
"--data-parallel-size"
,
"-dp"
,
parallel_group
.
add_argument
(
"--data-parallel-size"
,
"-dp"
,
**
parallel_kwargs
[
"data_parallel_size"
])
**
parallel_kwargs
[
"data_parallel_size"
])
parallel_group
.
add_argument
(
'--data-parallel-size-local'
,
'-dpl'
,
type
=
int
,
help
=
'Number of data parallel replicas '
'to run on this node.'
)
parallel_group
.
add_argument
(
'--data-parallel-address'
,
'-dpa'
,
type
=
str
,
help
=
'Address of data parallel cluster '
'head-node.'
)
parallel_group
.
add_argument
(
'--data-parallel-rpc-port'
,
'-dpp'
,
type
=
int
,
help
=
'Port for data parallel RPC '
'communication.'
)
parallel_group
.
add_argument
(
parallel_group
.
add_argument
(
"--enable-expert-parallel"
,
"--enable-expert-parallel"
,
**
parallel_kwargs
[
"enable_expert_parallel"
])
**
parallel_kwargs
[
"enable_expert_parallel"
])
...
@@ -1019,10 +1037,30 @@ class EngineArgs:
...
@@ -1019,10 +1037,30 @@ class EngineArgs:
# but we should not do this here.
# but we should not do this here.
placement_group
=
ray
.
util
.
get_current_placement_group
()
placement_group
=
ray
.
util
.
get_current_placement_group
()
# Local DP size defaults to global DP size if not set.
data_parallel_size_local
=
self
.
data_parallel_size
if
(
self
.
data_parallel_size_local
is
None
)
else
self
.
data_parallel_size_local
# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
data_parallel_address
=
self
.
data_parallel_address
if
(
self
.
data_parallel_address
is
not
None
)
else
ParallelConfig
.
data_parallel_master_ip
# This port is only used when there are remote data parallel engines,
# otherwise the local IPC transport is used.
data_parallel_rpc_port
=
self
.
data_parallel_rpc_port
if
(
self
.
data_parallel_rpc_port
is
not
None
)
else
ParallelConfig
.
data_parallel_rpc_port
parallel_config
=
ParallelConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
data_parallel_size
=
self
.
data_parallel_size
,
data_parallel_size
=
self
.
data_parallel_size
,
data_parallel_size_local
=
data_parallel_size_local
,
data_parallel_master_ip
=
data_parallel_address
,
data_parallel_rpc_port
=
data_parallel_rpc_port
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
...
...
vllm/entrypoints/cli/serve.py
View file @
55aa7af9
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
argparse
import
signal
import
uvloop
import
uvloop
import
vllm.envs
as
envs
from
vllm
import
AsyncEngineArgs
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
validate_parsed_serve_args
)
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
get_tcp_uri
from
vllm.v1.engine.core
import
EngineCoreProc
from
vllm.v1.engine.core_client
import
CoreEngineProcManager
from
vllm.v1.executor.abstract
import
Executor
logger
=
init_logger
(
__name__
)
class
ServeSubcommand
(
CLISubcommand
):
class
ServeSubcommand
(
CLISubcommand
):
...
@@ -24,6 +34,9 @@ class ServeSubcommand(CLISubcommand):
...
@@ -24,6 +34,9 @@ class ServeSubcommand(CLISubcommand):
if
hasattr
(
args
,
'model_tag'
)
and
args
.
model_tag
is
not
None
:
if
hasattr
(
args
,
'model_tag'
)
and
args
.
model_tag
is
not
None
:
args
.
model
=
args
.
model_tag
args
.
model
=
args
.
model_tag
if
args
.
headless
:
run_headless
(
args
)
else
:
uvloop
.
run
(
run_server
(
args
))
uvloop
.
run
(
run_server
(
args
))
def
validate
(
self
,
args
:
argparse
.
Namespace
)
->
None
:
def
validate
(
self
,
args
:
argparse
.
Namespace
)
->
None
:
...
@@ -42,6 +55,18 @@ class ServeSubcommand(CLISubcommand):
...
@@ -42,6 +55,18 @@ class ServeSubcommand(CLISubcommand):
nargs
=
'?'
,
nargs
=
'?'
,
help
=
"The model tag to serve "
help
=
"The model tag to serve "
"(optional if specified in config)"
)
"(optional if specified in config)"
)
serve_parser
.
add_argument
(
"--headless"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Run in headless mode. See multi-node data parallel "
"documentation for more details."
)
serve_parser
.
add_argument
(
'--data-parallel-start-rank'
,
'-dpr'
,
type
=
int
,
default
=
0
,
help
=
'Starting data parallel rank for secondary nodes.'
)
serve_parser
.
add_argument
(
serve_parser
.
add_argument
(
"--config"
,
"--config"
,
type
=
str
,
type
=
str
,
...
@@ -57,3 +82,55 @@ class ServeSubcommand(CLISubcommand):
...
@@ -57,3 +82,55 @@ class ServeSubcommand(CLISubcommand):
def
cmd_init
()
->
list
[
CLISubcommand
]:
def
cmd_init
()
->
list
[
CLISubcommand
]:
return
[
ServeSubcommand
()]
return
[
ServeSubcommand
()]
def
run_headless
(
args
:
argparse
.
Namespace
):
# Create the EngineConfig.
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
usage_context
=
UsageContext
.
OPENAI_API_SERVER
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
if
not
envs
.
VLLM_USE_V1
:
raise
RuntimeError
(
"Headless mode is only supported for V1"
)
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
host
=
parallel_config
.
data_parallel_master_ip
port
=
engine_args
.
data_parallel_rpc_port
# add to config too
input_address
=
get_tcp_uri
(
host
,
port
)
if
local_engine_count
<=
0
:
raise
RuntimeError
(
"data_parallel_size_local must be > 0 in "
"headless mode"
)
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def
signal_handler
(
signum
,
frame
):
logger
.
debug
(
"Received %d signal."
,
signum
)
raise
SystemExit
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
logger
.
info
(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s."
,
local_engine_count
,
input_address
)
# Create the engines.
engine_manager
=
CoreEngineProcManager
(
target_fn
=
EngineCoreProc
.
run_engine_core
,
local_engine_count
=
local_engine_count
,
start_index
=
args
.
data_parallel_start_rank
,
local_start_index
=
0
,
vllm_config
=
vllm_config
,
on_head_node
=
False
,
input_address
=
input_address
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
)
try
:
engine_manager
.
join_first
()
finally
:
logger
.
info
(
"Shutting down."
)
engine_manager
.
close
()
vllm/utils.py
View file @
55aa7af9
...
@@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool:
...
@@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool:
def
get_distributed_init_method
(
ip
:
str
,
port
:
int
)
->
str
:
def
get_distributed_init_method
(
ip
:
str
,
port
:
int
)
->
str
:
return
get_tcp_uri
(
ip
,
port
)
def
get_tcp_uri
(
ip
:
str
,
port
:
int
)
->
str
:
# Brackets are not permitted in ipv4 addresses,
# Brackets are not permitted in ipv4 addresses,
# see https://github.com/python/cpython/issues/103848
# see https://github.com/python/cpython/issues/103848
return
f
"tcp://[
{
ip
}
]:
{
port
}
"
if
":"
in
ip
else
f
"tcp://
{
ip
}
:
{
port
}
"
return
f
"tcp://[
{
ip
}
]:
{
port
}
"
if
":"
in
ip
else
f
"tcp://
{
ip
}
:
{
port
}
"
...
...
vllm/v1/engine/core.py
View file @
55aa7af9
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
os
import
os
import
queue
import
queue
import
signal
import
signal
...
@@ -23,7 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
...
@@ -23,7 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
resolve_obj_by_qualname
,
zmq_socket_ctx
from
vllm.utils
import
make_zmq_socket
,
resolve_obj_by_qualname
,
zmq_socket_ctx
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
unify_kv_cache_configs
)
unify_kv_cache_configs
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.interface
import
SchedulerInterface
...
@@ -43,6 +42,7 @@ from vllm.version import __version__ as VLLM_VERSION
...
@@ -43,6 +42,7 @@ from vllm.version import __version__ as VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
POLLING_TIMEOUT_S
=
2.5
POLLING_TIMEOUT_S
=
2.5
HANDSHAKE_TIMEOUT_MINS
=
5
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
...
@@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore):
...
@@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore):
def
__init__
(
def
__init__
(
self
,
self
,
input_path
:
str
,
output_path
:
str
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input_address
:
str
,
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
engine_index
:
int
=
0
,
engine_index
:
int
=
0
,
...
@@ -360,6 +360,26 @@ class EngineCoreProc(EngineCore):
...
@@ -360,6 +360,26 @@ class EngineCoreProc(EngineCore):
executor_fail_callback
=
lambda
:
input_queue
.
put_nowait
(
executor_fail_callback
=
lambda
:
input_queue
.
put_nowait
(
(
EngineCoreRequestType
.
EXECUTOR_FAILED
,
b
''
))
(
EngineCoreRequestType
.
EXECUTOR_FAILED
,
b
''
))
# Create input socket.
input_ctx
=
zmq
.
Context
()
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
input_socket
=
make_zmq_socket
(
input_ctx
,
input_address
,
zmq
.
DEALER
,
identity
=
identity
,
bind
=
False
)
try
:
# Register engine with front-end.
output_address
=
self
.
startup_handshake
(
input_socket
,
on_head_node
,
vllm_config
.
parallel_config
)
# Update config which may have changed from the handshake.
vllm_config
.
__post_init__
()
# Set up data parallel environment.
self
.
_init_data_parallel
(
vllm_config
)
# Initialize engine core and model.
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
executor_fail_callback
)
executor_fail_callback
)
...
@@ -367,6 +387,15 @@ class EngineCoreProc(EngineCore):
...
@@ -367,6 +387,15 @@ class EngineCoreProc(EngineCore):
self
.
step_with_batch_queue
)
self
.
step_with_batch_queue
)
self
.
engines_running
=
False
self
.
engines_running
=
False
# Send ready message.
num_gpu_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
input_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"READY"
,
"local"
:
on_head_node
,
"num_gpu_blocks"
:
num_gpu_blocks
,
}))
# Background Threads and Queues for IO. These enable us to
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# and to overlap some serialization/deserialization with the
...
@@ -375,13 +404,47 @@ class EngineCoreProc(EngineCore):
...
@@ -375,13 +404,47 @@ class EngineCoreProc(EngineCore):
self
.
input_queue
=
input_queue
self
.
input_queue
=
input_queue
self
.
output_queue
=
queue
.
Queue
[
Union
[
EngineCoreOutputs
,
bytes
]]()
self
.
output_queue
=
queue
.
Queue
[
Union
[
EngineCoreOutputs
,
bytes
]]()
threading
.
Thread
(
target
=
self
.
process_input_socket
,
threading
.
Thread
(
target
=
self
.
process_input_socket
,
args
=
(
input_
path
,
engine_index
),
args
=
(
input_
socket
,
),
daemon
=
True
).
start
()
daemon
=
True
).
start
()
input_socket
=
None
self
.
output_thread
=
threading
.
Thread
(
self
.
output_thread
=
threading
.
Thread
(
target
=
self
.
process_output_socket
,
target
=
self
.
process_output_socket
,
args
=
(
output_
path
,
engine_index
),
args
=
(
output_
address
,
engine_index
),
daemon
=
True
)
daemon
=
True
)
self
.
output_thread
.
start
()
self
.
output_thread
.
start
()
finally
:
if
input_socket
is
not
None
:
input_socket
.
close
(
linger
=
0
)
@
staticmethod
def
startup_handshake
(
input_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
parallel_config
:
ParallelConfig
)
->
str
:
# Send registration message.
input_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"HELLO"
,
"local"
:
on_head_node
,
}))
# Receive initialization message.
logger
.
info
(
"Waiting for init message from front-end."
)
if
not
input_socket
.
poll
(
timeout
=
HANDSHAKE_TIMEOUT_MINS
*
60
*
1000
):
raise
RuntimeError
(
"Did not receive response from front-end "
f
"process within
{
HANDSHAKE_TIMEOUT_MINS
}
"
f
"minutes"
)
init_bytes
=
input_socket
.
recv
()
init_message
=
msgspec
.
msgpack
.
decode
(
init_bytes
)
logger
.
debug
(
"Received init message: %s"
,
init_message
)
output_socket_address
=
init_message
[
"output_socket_address"
]
#TBD(nick) maybe replace IP with configured head node address
received_parallel_config
=
init_message
[
"parallel_config"
]
for
key
,
value
in
received_parallel_config
.
items
():
setattr
(
parallel_config
,
key
,
value
)
return
output_socket_address
@
staticmethod
@
staticmethod
def
run_engine_core
(
*
args
,
def
run_engine_core
(
*
args
,
...
@@ -412,7 +475,7 @@ class EngineCoreProc(EngineCore):
...
@@ -412,7 +475,7 @@ class EngineCoreProc(EngineCore):
try
:
try
:
parallel_config
:
ParallelConfig
=
kwargs
[
parallel_config
:
ParallelConfig
=
kwargs
[
"vllm_config"
].
parallel_config
"vllm_config"
].
parallel_config
if
parallel_config
.
data_parallel_size
>
1
:
if
parallel_config
.
data_parallel_size
>
1
or
dp_rank
>
0
:
# Set data parallel rank for this engine process.
# Set data parallel rank for this engine process.
parallel_config
.
data_parallel_rank
=
dp_rank
parallel_config
.
data_parallel_rank
=
dp_rank
parallel_config
.
data_parallel_rank_local
=
local_dp_rank
parallel_config
.
data_parallel_rank_local
=
local_dp_rank
...
@@ -436,6 +499,9 @@ class EngineCoreProc(EngineCore):
...
@@ -436,6 +499,9 @@ class EngineCoreProc(EngineCore):
if
engine_core
is
not
None
:
if
engine_core
is
not
None
:
engine_core
.
shutdown
()
engine_core
.
shutdown
()
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
pass
def
run_busy_loop
(
self
):
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore."""
"""Core busy loop of the EngineCore."""
...
@@ -527,36 +593,21 @@ class EngineCoreProc(EngineCore):
...
@@ -527,36 +593,21 @@ class EngineCoreProc(EngineCore):
logger
.
fatal
(
"vLLM shutdown signal from EngineCore failed "
logger
.
fatal
(
"vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue."
)
"to send. Please report this issue."
)
def
process_input_socket
(
self
,
input_
path
:
str
,
engine_index
:
in
t
):
def
process_input_socket
(
self
,
input_
socket
:
zmq
.
Socke
t
):
"""Input socket IO thread."""
"""Input socket IO thread."""
# Msgpack serialization decoding.
# Msgpack serialization decoding.
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
generic_decoder
=
MsgpackDecoder
()
generic_decoder
=
MsgpackDecoder
()
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
with
zmq_socket_ctx
(
input_path
,
zmq
.
DEALER
,
identity
=
identity
,
bind
=
False
)
as
socket
:
# Send ready message to front-end once input socket is connected.
message_dict
=
{
'type'
:
'READY'
,
'num_gpu_blocks'
:
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
,
}
message
=
json
.
dumps
(
message_dict
).
encode
(
'utf-8'
)
socket
.
send
(
message
)
while
True
:
while
True
:
# (RequestType, RequestData)
# (RequestType, RequestData)
type_frame
,
*
data_frames
=
socket
.
recv_multipart
(
copy
=
False
)
type_frame
,
*
data_frames
=
input_
socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
decoder
=
add_request_decoder
if
(
request_type
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frames
)
request
=
decoder
.
decode
(
data_frames
)
# Push to input queue for core busy loop.
# Push to input queue for core busy loop.
...
@@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc):
def
__init__
(
def
__init__
(
self
,
self
,
input_path
:
str
,
output_path
:
str
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input_address
:
str
,
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
):
):
...
@@ -623,8 +674,20 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -623,8 +674,20 @@ class DPEngineCoreProc(EngineCoreProc):
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self
.
counter
=
0
# Initialize the engine.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
super
().
__init__
(
vllm_config
,
on_head_node
,
input_address
,
executor_class
,
log_stats
,
dp_rank
)
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
# Configure GPUs and stateless process group for data parallel.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
local_dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank_local
local_dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank_local
assert
dp_size
>
1
assert
dp_size
>
1
...
@@ -632,24 +695,16 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -632,24 +695,16 @@ class DPEngineCoreProc(EngineCoreProc):
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
device_control_env_var
=
current_platform
.
device_control_env_var
device_control_env_var
=
current_platform
.
device_control_env_var
tp
_size
=
vllm_config
.
parallel_config
.
tensor_parallel
_size
world
_size
=
vllm_config
.
parallel_config
.
world
_size
os
.
environ
[
device_control_env_var
]
=
","
.
join
(
os
.
environ
[
device_control_env_var
]
=
","
.
join
(
str
(
current_platform
.
device_id_to_physical_device_id
(
i
))
str
(
current_platform
.
device_id_to_physical_device_id
(
i
))
for
i
in
range
(
local_dp_rank
*
tp
_size
,
(
local_dp_rank
+
1
)
*
for
i
in
range
(
local_dp_rank
*
world
_size
,
(
local_dp_rank
+
1
)
*
tp
_size
))
world
_size
))
self
.
local_dp_rank
=
local_dp_rank
self
.
local_dp_rank
=
local_dp_rank
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
current_wave
=
0
self
.
current_wave
=
0
# Initialize the engine after setting up environment.
super
().
__init__
(
input_path
,
output_path
,
vllm_config
,
executor_class
,
log_stats
,
dp_rank
)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self
.
counter
=
0
def
shutdown
(
self
):
def
shutdown
(
self
):
super
().
shutdown
()
super
().
shutdown
()
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
...
...
vllm/v1/engine/core_client.py
View file @
55aa7af9
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
asyncio
import
contextlib
import
contextlib
import
json
import
queue
import
queue
import
uuid
import
uuid
import
weakref
import
weakref
...
@@ -9,25 +8,27 @@ from abc import ABC, abstractmethod
...
@@ -9,25 +8,27 @@ from abc import ABC, abstractmethod
from
collections
import
deque
from
collections
import
deque
from
collections.abc
import
Awaitable
,
Sequence
from
collections.abc
import
Awaitable
,
Sequence
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
msgspec
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
from
vllm.config
import
VllmConfig
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.utils
import
(
get_open_
zmq_inproc_path
,
get_open_zmq_i
p
c_path
,
from
vllm.utils
import
(
get_open_
port
,
get_open_zmq_i
npro
c_path
,
make_zmq_socket
)
get_open_zmq_ipc_path
,
get_tcp_uri
,
make_zmq_socket
)
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
Background
Proc
H
an
dle
from
vllm.v1.utils
import
CoreEngine
Proc
M
an
ager
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -264,45 +265,22 @@ class InprocClient(EngineCoreClient):
...
@@ -264,45 +265,22 @@ class InprocClient(EngineCoreClient):
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
CoreEngineState
(
Enum
):
NEW
=
auto
()
CONNECTED
=
auto
()
READY
=
auto
()
class
CoreEngine
:
class
CoreEngine
:
"""One per data parallel rank."""
"""One per data parallel rank."""
def
__init__
(
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
self
,
self
.
local
=
local
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
input_path
:
str
,
output_path
:
str
,
index
:
int
=
0
,
local_dp_rank
:
int
=
0
,
):
self
.
index
=
index
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
self
.
identity
=
index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
try
:
# Start EngineCore in background process.
self
.
proc_handle
=
BackgroundProcHandle
(
input_path
=
input_path
,
output_path
=
output_path
,
process_name
=
f
"EngineCore_
{
index
}
"
,
target_fn
=
EngineCoreProc
.
run_engine_core
,
process_kwargs
=
{
"vllm_config"
:
vllm_config
,
"dp_rank"
:
index
,
"local_dp_rank"
:
local_dp_rank
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
})
self
.
state
=
CoreEngineState
.
NEW
self
.
num_reqs_in_flight
=
0
self
.
num_reqs_in_flight
=
0
finally
:
if
not
hasattr
(
self
,
"num_reqs_in_flight"
):
# Ensure socket is closed if process fails to start.
self
.
close
()
def
close
(
self
):
if
proc_handle
:
=
getattr
(
self
,
"proc_handle"
,
None
):
proc_handle
.
shutdown
()
@
dataclass
@
dataclass
...
@@ -311,7 +289,7 @@ class BackgroundResources:
...
@@ -311,7 +289,7 @@ class BackgroundResources:
circular reference back to the client object."""
circular reference back to the client object."""
ctx
:
Union
[
zmq
.
Context
]
ctx
:
Union
[
zmq
.
Context
]
core
_engine
s
:
list
[
CoreEngine
]
=
field
(
default_factory
=
list
)
local
_engine
_manager
:
Optional
[
CoreEngineProcManager
]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
output_queue_task
:
Optional
[
asyncio
.
Task
]
=
None
output_queue_task
:
Optional
[
asyncio
.
Task
]
=
None
...
@@ -325,8 +303,8 @@ class BackgroundResources:
...
@@ -325,8 +303,8 @@ class BackgroundResources:
"""Clean up background resources."""
"""Clean up background resources."""
self
.
engine_dead
=
True
self
.
engine_dead
=
True
for
core_engine
in
self
.
core_engi
ne
s
:
if
self
.
local_engine_manager
is
not
No
ne
:
core_engine
.
close
()
self
.
local_engine_manager
.
close
()
if
self
.
output_queue_task
is
not
None
:
if
self
.
output_queue_task
is
not
None
:
self
.
output_queue_task
.
cancel
()
self
.
output_queue_task
.
cancel
()
...
@@ -388,25 +366,56 @@ class MPClient(EngineCoreClient):
...
@@ -388,25 +366,56 @@ class MPClient(EngineCoreClient):
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
resources
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
resources
)
success
=
False
success
=
False
try
:
try
:
# Paths and sockets for IPC.
parallel_config
=
vllm_config
.
parallel_config
self
.
output_path
=
get_open_zmq_ipc_path
()
local_engine_count
=
parallel_config
.
data_parallel_size_local
input_path
=
get_open_zmq_ipc_path
()
start_index
=
parallel_config
.
data_parallel_rank
self
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
local_start_index
=
parallel_config
.
data_parallel_rank_local
input_path
,
zmq
.
ROUTER
,
# SPMD mode is where there is an LLM instance per DP rank and
bind
=
True
)
# one core engine per LLM, see
self
.
resources
.
input_socket
=
self
.
input_socket
# examples/offline_inference/data_parallel.py.
spmd_mode
=
local_start_index
is
not
None
new_core_engine
=
lambda
index
,
local_dp_rank
=
None
:
CoreEngine
(
if
spmd_mode
:
vllm_config
,
executor_class
,
log_stats
,
input_path
,
self
.
assert
local_engine_count
==
1
output_path
,
index
,
local_dp_rank
)
self
.
core_engines
=
[
CoreEngine
(
index
=
local_start_index
,
local
=
True
)
# Start engine core process(es).
]
self
.
_init_core_engines
(
vllm_config
,
new_core_engine
,
else
:
self
.
resources
.
core_engines
)
assert
start_index
==
0
local_start_index
=
0
self
.
core_engines
=
[
CoreEngine
(
index
=
i
,
local
=
(
i
<
local_engine_count
))
for
i
in
range
(
parallel_config
.
data_parallel_size
)
]
input_address
,
output_address
=
self
.
_get_zmq_addresses
(
parallel_config
,
spmd_mode
)
# Create input and output sockets.
self
.
input_socket
=
self
.
resources
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_address
,
zmq
.
ROUTER
,
bind
=
True
)
self
.
resources
.
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_address
,
zmq
.
constants
.
PULL
)
# Start local engines.
if
local_engine_count
:
# In server mode, start_index and local_start_index will
# both be 0.
self
.
resources
.
local_engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
input_address
=
input_address
,
on_head_node
=
True
,
local_engine_count
=
local_engine_count
,
start_index
=
start_index
,
local_start_index
=
local_start_index
)
self
.
core_engine
=
self
.
core_engines
[
0
]
# Wait for engine core process(es) to start.
# Wait for engine core process(es) to start.
self
.
_wait_for_engine_startup
()
self
.
_wait_for_engine_startup
(
output_address
,
parallel_config
)
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
...
@@ -420,56 +429,116 @@ class MPClient(EngineCoreClient):
...
@@ -420,56 +429,116 @@ class MPClient(EngineCoreClient):
if
not
success
:
if
not
success
:
self
.
_finalizer
()
self
.
_finalizer
()
def
_wait_for_engine_startup
(
self
):
@
staticmethod
def
_get_zmq_addresses
(
parallel_config
:
ParallelConfig
,
spmd_mode
:
bool
)
->
tuple
[
str
,
str
]:
"""Returns (input_address, output_address)."""
dp_size
=
parallel_config
.
data_parallel_size
local_engine_count
=
parallel_config
.
data_parallel_size_local
if
local_engine_count
==
dp_size
or
spmd_mode
:
input_address
=
get_open_zmq_ipc_path
()
output_address
=
get_open_zmq_ipc_path
()
else
:
host
=
parallel_config
.
data_parallel_master_ip
input_port
=
parallel_config
.
data_parallel_rpc_port
output_port
=
get_open_port
()
input_address
=
get_tcp_uri
(
host
,
input_port
)
output_address
=
get_tcp_uri
(
host
,
output_port
)
return
input_address
,
output_address
def
_wait_for_engine_startup
(
self
,
output_address
:
str
,
parallel_config
:
ParallelConfig
):
# Get a sync handle to the socket which can be sync or async.
# Get a sync handle to the socket which can be sync or async.
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
# Wait for engine core process(es) to send ready messages.
# Wait for engine core process(es) to send ready messages.
identities
=
set
(
eng
.
index
for
eng
in
self
.
resources
.
core_engines
)
local_count
=
parallel_config
.
data_parallel_size_local
remote_count
=
len
(
self
.
core_engines
)
-
local_count
# [local, remote] counts
conn_pending
,
start_pending
=
[
local_count
,
remote_count
],
[
0
,
0
]
poller
=
zmq
.
Poller
()
poller
=
zmq
.
Poller
()
poller
.
register
(
sync_input_socket
,
zmq
.
POLLIN
)
poller
.
register
(
sync_input_socket
,
zmq
.
POLLIN
)
for
eng
in
self
.
resources
.
core_engines
:
proc_manager
=
self
.
resources
.
local_engine_manager
poller
.
register
(
eng
.
proc_handle
,
zmq
.
POLLIN
)
if
proc_manager
is
not
None
:
while
identities
:
for
sentinel
in
proc_manager
.
sentinels
():
poller
.
register
(
sentinel
,
zmq
.
POLLIN
)
while
any
(
conn_pending
)
or
any
(
start_pending
):
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
if
not
events
:
if
not
events
:
logger
.
debug
(
"Waiting for %d core engine proc(s) to start: %s"
,
if
any
(
conn_pending
):
len
(
identities
),
identities
)
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect."
,
*
conn_pending
)
if
any
(
start_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to start."
,
*
start_pending
)
continue
continue
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
sync_input_socket
:
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
sync_input_socket
:
# One of the core processes exited.
# One of the local core processes exited.
finished
=
proc_manager
.
finished_procs
(
)
if
proc_manager
else
{}
raise
RuntimeError
(
"Engine core initialization failed. "
raise
RuntimeError
(
"Engine core initialization failed. "
"See root cause above."
)
"See root cause above. "
f
"Failed core proc(s):
{
finished
}
"
)
eng_id_bytes
,
data
=
sync_input_socket
.
recv_multipart
()
eng_id
=
int
.
from_bytes
(
eng_id_bytes
,
byteorder
=
"little"
)
# Receive HELLO and READY messages from the input socket.
if
eng_id
not
in
identities
:
eng_identity
,
ready_msg_bytes
=
sync_input_socket
.
recv_multipart
()
raise
RuntimeError
(
f
"Unexpected or duplicate engine:
{
eng_id
}
"
)
eng_index
=
int
.
from_bytes
(
eng_identity
,
byteorder
=
"little"
)
message_dict
=
json
.
loads
(
data
.
decode
(
'utf-8'
))
engine
=
next
(
if
message_dict
[
'type'
]
!=
'READY'
:
(
e
for
e
in
self
.
core_engines
if
e
.
identity
==
eng_identity
),
raise
RuntimeError
(
f
"Engine
{
eng_id
}
failed:
{
data
.
decode
()
}
"
)
None
)
logger
.
info
(
"Core engine process %d ready."
,
eng_id
)
if
engine
is
None
:
identities
.
discard
(
eng_id
)
raise
RuntimeError
(
f
"Message from engine with unexpected data "
f
"parallel rank:
{
eng_index
}
"
)
msg
=
msgspec
.
msgpack
.
decode
(
ready_msg_bytes
)
status
,
local
=
msg
[
"status"
],
msg
[
"local"
]
if
local
!=
engine
.
local
:
raise
RuntimeError
(
f
"
{
status
}
message from "
f
"
{
'local'
if
local
else
'remote'
}
"
f
"engine
{
eng_index
}
, expected it to be "
f
"
{
'local'
if
engine
.
local
else
'remote'
}
"
)
if
status
==
"HELLO"
and
engine
.
state
==
CoreEngineState
.
NEW
:
# Send init message with DP config info.
init_message
=
self
.
encoder
.
encode
({
"output_socket_address"
:
output_address
,
"parallel_config"
:
{
"data_parallel_master_ip"
:
parallel_config
.
data_parallel_master_ip
,
"data_parallel_master_port"
:
parallel_config
.
data_parallel_master_port
,
"data_parallel_size"
:
parallel_config
.
data_parallel_size
,
},
})
sync_input_socket
.
send_multipart
((
eng_identity
,
*
init_message
),
copy
=
False
)
conn_pending
[
0
if
local
else
1
]
-=
1
start_pending
[
0
if
local
else
1
]
+=
1
engine
.
state
=
CoreEngineState
.
CONNECTED
elif
status
==
"READY"
and
(
engine
.
state
==
CoreEngineState
.
CONNECTED
):
# Setup KV cache config with initialization state from
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks
=
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
or
0
cache_config
=
self
.
vllm_config
.
cache_config
num_gpu_blocks
+=
message_dict
[
'num_gpu_blocks'
]
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
or
0
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
num_gpu_blocks
+=
msg
[
'num_gpu_blocks'
]
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
def
_init_core_engines
(
start_pending
[
0
if
local
else
1
]
-=
1
self
,
engine
.
state
=
CoreEngineState
.
READY
vllm_config
:
VllmConfig
,
else
:
new_core_engine
:
Callable
[[
int
,
Optional
[
int
]],
CoreEngine
],
raise
RuntimeError
(
f
"Unexpected
{
status
}
message for "
core_engines
:
list
[
CoreEngine
],
f
"
{
'local'
if
local
else
'remote'
}
engine "
)
->
None
:
f
"
{
eng_index
}
in
{
engine
.
state
}
state."
)
# Default case - single core engine.
logger
.
debug
(
"%s from %s core engine process %s."
,
status
,
core_engine
=
new_core_engine
(
"local"
if
local
else
"remote"
,
eng_index
)
vllm_config
.
parallel_config
.
data_parallel_rank
,
vllm_config
.
parallel_config
.
data_parallel_rank_local
,
)
core_engines
.
append
(
core_engine
)
self
.
core_engine
=
core_engine
def
shutdown
(
self
):
def
shutdown
(
self
):
# Terminate background resources.
# Terminate background resources.
...
@@ -520,7 +589,8 @@ class SyncMPClient(MPClient):
...
@@ -520,7 +589,8 @@ class SyncMPClient(MPClient):
# Ensure that the outputs socket processing thread does not have
# Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc.
# a ref to the client which prevents gc.
ctx
=
self
.
ctx
ctx
=
self
.
ctx
output_path
=
self
.
output_path
out_socket
=
self
.
resources
.
output_socket
assert
out_socket
is
not
None
decoder
=
self
.
decoder
decoder
=
self
.
decoder
utility_results
=
self
.
utility_results
utility_results
=
self
.
utility_results
outputs_queue
=
self
.
outputs_queue
outputs_queue
=
self
.
outputs_queue
...
@@ -531,7 +601,6 @@ class SyncMPClient(MPClient):
...
@@ -531,7 +601,6 @@ class SyncMPClient(MPClient):
def
process_outputs_socket
():
def
process_outputs_socket
():
shutdown_socket
=
ctx
.
socket
(
zmq
.
PAIR
)
shutdown_socket
=
ctx
.
socket
(
zmq
.
PAIR
)
out_socket
=
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
try
:
try
:
shutdown_socket
.
bind
(
shutdown_path
)
shutdown_socket
.
bind
(
shutdown_path
)
poller
=
zmq
.
Poller
()
poller
=
zmq
.
Poller
()
...
@@ -566,6 +635,9 @@ class SyncMPClient(MPClient):
...
@@ -566,6 +635,9 @@ class SyncMPClient(MPClient):
daemon
=
True
)
daemon
=
True
)
self
.
output_queue_thread
.
start
()
self
.
output_queue_thread
.
start
()
# The thread takes on responsibility for closing the socket.
self
.
resources
.
output_socket
=
None
def
get_output
(
self
)
->
EngineCoreOutputs
:
def
get_output
(
self
)
->
EngineCoreOutputs
:
# If an exception arises in process_outputs_socket task,
# If an exception arises in process_outputs_socket task,
# it is forwarded to the outputs_queue so we can raise it
# it is forwarded to the outputs_queue so we can raise it
...
@@ -693,10 +765,8 @@ class AsyncMPClient(MPClient):
...
@@ -693,10 +765,8 @@ class AsyncMPClient(MPClient):
self
.
__class__
,
self
.
__class__
,
"process_engine_outputs"
,
None
)
"process_engine_outputs"
,
None
)
_self_ref
=
weakref
.
ref
(
self
)
if
output_handler
else
None
_self_ref
=
weakref
.
ref
(
self
)
if
output_handler
else
None
output_path
=
self
.
output_path
output_socket
=
resources
.
output_socket
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_path
,
assert
output_socket
is
not
None
zmq
.
constants
.
PULL
)
resources
.
output_socket
=
output_socket
async
def
process_outputs_socket
():
async
def
process_outputs_socket
():
try
:
try
:
...
@@ -861,21 +931,6 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -861,21 +931,6 @@ class DPAsyncMPClient(AsyncMPClient):
assert
len
(
self
.
core_engines
)
>
1
assert
len
(
self
.
core_engines
)
>
1
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
new_core_engine
:
Callable
[[
int
,
Optional
[
int
]],
CoreEngine
],
core_engines
:
list
[
CoreEngine
],
)
->
None
:
# Launch a core engine for each data parallel rank.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
for
i
in
range
(
dp_size
):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines
.
append
(
new_core_engine
(
i
,
i
))
self
.
core_engines
=
core_engines
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
# Only the result from the first engine is returned.
# Only the result from the first engine is returned.
return
(
await
asyncio
.
gather
(
*
[
return
(
await
asyncio
.
gather
(
*
[
...
...
vllm/v1/utils.py
View file @
55aa7af9
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
os
import
time
import
weakref
import
weakref
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
multiprocessing
import
Process
from
multiprocessing
import
Process
,
connection
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
Union
,
overload
)
overload
)
import
torch
import
torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
get_mp_context
,
kill_process_tree
from
vllm.utils
import
get_mp_context
,
kill_process_tree
from
vllm.v1.executor.abstract
import
Executor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
...
@@ -92,7 +95,7 @@ class ConstantList(Generic[T], Sequence):
...
@@ -92,7 +95,7 @@ class ConstantList(Generic[T], Sequence):
return
f
"ConstantList(
{
self
.
_x
}
)"
return
f
"ConstantList(
{
self
.
_x
}
)"
class
Background
Proc
H
an
dle
:
class
CoreEngine
Proc
M
an
ager
:
"""
"""
Utility class to handle creation, readiness, and shutdown
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
of background processes used by the AsyncLLM and LLMEngine.
...
@@ -100,49 +103,91 @@ class BackgroundProcHandle:
...
@@ -100,49 +103,91 @@ class BackgroundProcHandle:
def
__init__
(
def
__init__
(
self
,
self
,
input_path
:
str
,
output_path
:
str
,
process_name
:
str
,
target_fn
:
Callable
,
target_fn
:
Callable
,
process_kwargs
:
dict
[
Any
,
Any
],
local_engine_count
:
int
,
start_index
:
int
,
local_start_index
:
int
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
):
):
context
=
get_mp_context
()
context
=
get_mp_context
()
common_kwargs
=
{
"vllm_config"
:
vllm_config
,
"on_head_node"
:
on_head_node
,
"input_address"
:
input_address
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
}
self
.
processes
:
list
[
Process
]
=
[]
for
index
in
range
(
local_engine_count
):
local_index
=
local_start_index
+
index
global_index
=
start_index
+
index
# Start EngineCore in background process.
self
.
processes
.
append
(
context
.
Process
(
target
=
target_fn
,
name
=
f
"EngineCore_
{
global_index
}
"
,
kwargs
=
common_kwargs
|
{
"dp_rank"
:
global_index
,
"local_dp_rank"
:
local_index
,
}))
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
,
input_address
)
try
:
for
proc
in
self
.
processes
:
proc
.
start
()
finally
:
# Kill other procs if not all are running.
if
self
.
finished_procs
():
self
.
close
()
def
close
(
self
):
"""Shutdown all procs."""
self
.
_finalizer
()
assert
(
"input_path"
not
in
process_kwargs
def
join_first
(
self
):
and
"output_path"
not
in
process_kwargs
)
"""Wait for any process to exit."""
process_kwargs
[
"input_path"
]
=
input_path
connection
.
wait
(
proc
.
sentinel
for
proc
in
self
.
processes
)
process_kwargs
[
"output_path"
]
=
output_path
# Run busy loop in background process.
self
.
proc
:
Process
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
,
name
=
process_name
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
proc
,
input_path
,
output_path
)
self
.
proc
.
start
()
def
fileno
(
self
)
:
def
sentinels
(
self
)
->
list
:
return
self
.
proc
.
sentinel
return
[
proc
.
sentinel
for
proc
in
self
.
processes
]
def
shutdown
(
self
):
def
finished_procs
(
self
)
->
dict
[
str
,
int
]:
self
.
_finalizer
()
"""Returns dict of proc name -> exit code for any finished procs."""
return
{
proc
.
name
:
proc
.
exitcode
for
proc
in
self
.
processes
if
proc
.
exitcode
is
not
None
}
# Note(rob): shutdown function cannot be a bound method,
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
# else the gc cannot collect the object.
def
shutdown
(
proc
:
Process
,
input_
path
:
str
,
output_path
:
str
):
def
shutdown
(
proc
s
:
list
[
Process
]
,
input_
address
:
str
):
# Shutdown the process.
# Shutdown the process.
for
proc
in
procs
:
if
proc
.
is_alive
():
if
proc
.
is_alive
():
proc
.
terminate
()
proc
.
terminate
()
proc
.
join
(
5
)
# Allow 5 seconds for remaining procs to terminate.
deadline
=
time
.
monotonic
()
+
5
for
proc
in
procs
:
remaining
=
deadline
-
time
.
monotonic
()
if
remaining
<=
0
:
break
if
proc
.
is_alive
():
proc
.
join
(
remaining
)
for
proc
in
procs
:
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
kill_process_tree
(
pid
)
kill_process_tree
(
pid
)
# Remove zmq ipc socket files.
# Remove zmq ipc socket files.
ipc_sockets
=
[
output_path
,
input_path
]
if
input_address
.
startswith
(
"ipc://"
):
for
ipc_socket
in
ipc_sockets
:
socket_file
=
input_address
[
len
(
"ipc://"
):]
socket_file
=
ipc_socket
.
replace
(
"ipc://"
,
""
)
if
os
and
os
.
path
.
exists
(
socket_file
):
if
os
and
os
.
path
.
exists
(
socket_file
):
os
.
remove
(
socket_file
)
os
.
remove
(
socket_file
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment