Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de1a86b7
Unverified
Commit
de1a86b7
authored
Mar 18, 2026
by
Itay Alroy
Committed by
GitHub
Mar 18, 2026
Browse files
elastic_ep: Fix stateless group port races (#36330)
Signed-off-by:
Itay Alroy
<
ialroy@nvidia.com
>
parent
99267c23
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
224 additions
and
225 deletions
+224
-225
.buildkite/test_areas/expert_parallelism.yaml
.buildkite/test_areas/expert_parallelism.yaml
+1
-2
vllm/config/parallel.py
vllm/config/parallel.py
+31
-85
vllm/distributed/elastic_ep/elastic_execute.py
vllm/distributed/elastic_ep/elastic_execute.py
+2
-4
vllm/distributed/elastic_ep/elastic_state.py
vllm/distributed/elastic_ep/elastic_state.py
+1
-12
vllm/distributed/elastic_ep/standby_state.py
vllm/distributed/elastic_ep/standby_state.py
+15
-9
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+20
-21
vllm/distributed/stateless_coordinator.py
vllm/distributed/stateless_coordinator.py
+50
-7
vllm/distributed/utils.py
vllm/distributed/utils.py
+60
-18
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+1
-4
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+1
-0
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+28
-62
vllm/v1/engine/utils.py
vllm/v1/engine/utils.py
+14
-1
No files found.
.buildkite/test_areas/expert_parallelism.yaml
View file @
de1a86b7
...
@@ -24,8 +24,7 @@ steps:
...
@@ -24,8 +24,7 @@ steps:
-
label
:
Elastic EP Scaling Test
-
label
:
Elastic EP Scaling Test
timeout_in_minutes
:
20
timeout_in_minutes
:
20
device
:
b200
device
:
h100
optional
:
true
working_dir
:
"
/vllm-workspace/tests"
working_dir
:
"
/vllm-workspace/tests"
num_devices
:
4
num_devices
:
4
source_file_dependencies
:
source_file_dependencies
:
...
...
vllm/config/parallel.py
View file @
de1a86b7
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
os
import
socket
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
overload
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
overload
...
@@ -266,33 +267,9 @@ class ParallelConfig:
...
@@ -266,33 +267,9 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
Set to be private as it's not intended to be configured by users.
"""
"""
_stateless_dp_group_port_list
:
list
[
list
[
int
]]
=
Field
(
default_factory
=
list
)
_coord_store_port
:
int
=
0
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
"""Port of the coordination TCPStore. Can be set by the API server; workers
Set to be private as it's not intended to be configured by users.
connect as clients to exchange self-picked group ports at runtime."""
It is a list of list[int], with each inner list contains a set of 3 ports
to be used for setting up the stateless CPU/device/TCPStore groups
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
"""
_stateless_ep_group_port_list
:
list
[
list
[
int
]]
=
Field
(
default_factory
=
list
)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
"""
_stateless_eplb_group_port_list
:
list
[
list
[
int
]]
=
Field
(
default_factory
=
list
)
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
Same topology as EP but separate NCCL communicator to avoid deadlocks.
"""
_stateless_world_group_port_list
:
list
[
list
[
int
]]
=
Field
(
default_factory
=
list
)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_world_group_port_list) == 1,
"""
decode_context_parallel_size
:
int
=
1
decode_context_parallel_size
:
int
=
1
"""Number of decode context parallel groups, because the world size does
"""Number of decode context parallel groups, because the world size does
...
@@ -465,65 +442,32 @@ class ParallelConfig:
...
@@ -465,65 +442,32 @@ class ParallelConfig:
return
answer
return
answer
def
allocate_elastic_e
p_port
s
(
self
)
->
None
:
def
_pick_stateless_d
p_port
(
self
)
->
tuple
[
int
,
socket
.
socket
|
None
]
:
"""
Allocate all ports for elastic EP (stateless groups + DP master)
.
"""
Return ``(port, listen_socket)`` for DP group init
.
Must be called AFTER ray.init() so that ports claimed by Ray's
With a coord store, rank 0 binds a socket and publishes the port;
idle worker pool are al
read
y
i
n use and won't be returned by
others
read i
t. Without one, pops a pre-allocated port and
g
et
_open_ports_list()
.
r
et
urns ``listen_socket=None``
.
"""
"""
if
not
self
.
enable_elastic_ep
:
if
not
self
.
_coord_store_port
:
return
return
self
.
get_next_dp_init_port
(),
None
if
self
.
_stateless_world_group_port_list
:
return
from
vllm.distributed.utils
import
get_cached_tcp_store_client
num_world_groups
=
1
store
=
get_cached_tcp_store_client
(
dp_size
=
self
.
data_parallel_size
self
.
data_parallel_master_ip
,
self
.
_coord_store_port
ep_size
=
self
.
data_parallel_size
*
self
.
world_size_across_dp
)
num_dp_groups
=
max
(
1
,
self
.
world_size_across_dp
//
dp_size
)
num_ep_groups
=
max
(
1
,
self
.
world_size_across_dp
//
ep_size
)
key
=
"dp_master_port"
num_eplb_groups
=
num_ep_groups
if
self
.
data_parallel_rank
==
0
:
total_stateless_ports
=
(
s
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
num_world_groups
+
num_dp_groups
+
num_ep_groups
+
num_eplb_groups
s
.
bind
((
self
.
data_parallel_master_ip
,
0
))
)
*
3
s
.
listen
()
num_dp_master_ports
=
5
port
=
s
.
getsockname
()[
1
]
store
.
set
(
key
,
str
(
port
).
encode
())
all_ports
=
get_open_ports_list
(
total_stateless_ports
+
num_dp_master_ports
)
return
port
,
s
else
:
self
.
_data_parallel_master_port_list
=
all_ports
[
-
num_dp_master_ports
:]
return
int
(
store
.
get
(
key
).
decode
()),
None
self
.
data_parallel_master_port
=
self
.
_data_parallel_master_port_list
.
pop
()
all_ports
=
all_ports
[:
-
num_dp_master_ports
]
self
.
_stateless_world_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
0
,
num_world_groups
*
3
,
3
)
]
start_idx
=
num_world_groups
*
3
self
.
_stateless_dp_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_dp_groups
*
3
,
3
)
]
start_idx
+=
num_dp_groups
*
3
self
.
_stateless_ep_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_ep_groups
*
3
,
3
)
]
start_idx
+=
num_ep_groups
*
3
self
.
_stateless_eplb_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_eplb_groups
*
3
,
3
)
]
def
get_next_stateless_world_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_world_group_port_list
.
pop
()
def
get_next_stateless_dp_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_dp_group_port_list
.
pop
()
def
get_next_stateless_ep_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_ep_group_port_list
.
pop
()
def
get_next_stateless_eplb_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_eplb_group_port_list
.
pop
()
@
overload
@
overload
def
stateless_init_dp_group
(
def
stateless_init_dp_group
(
...
@@ -553,14 +497,16 @@ class ParallelConfig:
...
@@ -553,14 +497,16 @@ class ParallelConfig:
last_exc
:
Exception
|
None
=
None
last_exc
:
Exception
|
None
=
None
for
_
in
range
(
max_retries
):
for
_
in
range
(
max_retries
):
try
:
try
:
port
,
listen_socket
=
self
.
_pick_stateless_dp_port
()
# use gloo since the engine process might not have cuda device
# use gloo since the engine process might not have cuda device
return
stateless_init_torch_distributed_process_group
(
return
stateless_init_torch_distributed_process_group
(
self
.
data_parallel_master_ip
,
self
.
data_parallel_master_ip
,
self
.
get_next_dp_init_
port
()
,
port
,
self
.
data_parallel_rank
,
self
.
data_parallel_rank
,
self
.
data_parallel_size
,
self
.
data_parallel_size
,
backend
=
"gloo"
,
backend
=
"gloo"
,
return_store
=
return_store
,
return_store
=
return_store
,
listen_socket
=
listen_socket
,
)
)
except
DistNetworkError
as
e
:
except
DistNetworkError
as
e
:
# We only want to retry when the root cause is EADDRINUSE.
# We only want to retry when the root cause is EADDRINUSE.
...
...
vllm/distributed/elastic_ep/elastic_execute.py
View file @
de1a86b7
...
@@ -162,10 +162,8 @@ class ElasticEPScalingExecutor:
...
@@ -162,10 +162,8 @@ class ElasticEPScalingExecutor:
new_dp_size
=
new_dp_size
,
new_dp_size
=
new_dp_size
,
new_world_size_across_dp
=
new_world_size_across_dp
,
new_world_size_across_dp
=
new_world_size_across_dp
,
master_ip
=
reconfig_request
.
new_data_parallel_master_ip
,
master_ip
=
reconfig_request
.
new_data_parallel_master_ip
,
world_group_ports
=
reconfig_request
.
new_stateless_world_group_port_list
,
coord_store_port
=
reconfig_request
.
coord_store_port
,
dp_group_ports
=
reconfig_request
.
new_stateless_dp_group_port_list
,
enable_eplb
=
updated_config
.
parallel_config
.
enable_eplb
,
ep_group_ports
=
reconfig_request
.
new_stateless_ep_group_port_list
,
eplb_group_ports
=
reconfig_request
.
new_stateless_eplb_group_port_list
,
)
)
self
.
worker
.
model_runner
.
eep_eplb_suppressed
=
True
self
.
worker
.
model_runner
.
eep_eplb_suppressed
=
True
standby_ep_group
=
get_standby_ep_group
()
standby_ep_group
=
get_standby_ep_group
()
...
...
vllm/distributed/elastic_ep/elastic_state.py
View file @
de1a86b7
...
@@ -563,15 +563,4 @@ class ElasticEPScalingState:
...
@@ -563,15 +563,4 @@ class ElasticEPScalingState:
parallel_config
.
_data_parallel_master_port_list
=
(
parallel_config
.
_data_parallel_master_port_list
=
(
reconfig_request
.
new_data_parallel_master_port_list
reconfig_request
.
new_data_parallel_master_port_list
)
)
parallel_config
.
_stateless_world_group_port_list
=
(
parallel_config
.
_coord_store_port
=
reconfig_request
.
coord_store_port
reconfig_request
.
new_stateless_world_group_port_list
)
parallel_config
.
_stateless_dp_group_port_list
=
(
reconfig_request
.
new_stateless_dp_group_port_list
)
parallel_config
.
_stateless_ep_group_port_list
=
(
reconfig_request
.
new_stateless_ep_group_port_list
)
parallel_config
.
_stateless_eplb_group_port_list
=
(
reconfig_request
.
new_stateless_eplb_group_port_list
)
vllm/distributed/elastic_ep/standby_state.py
View file @
de1a86b7
...
@@ -38,10 +38,8 @@ def create_standby_groups(
...
@@ -38,10 +38,8 @@ def create_standby_groups(
new_dp_size
:
int
,
new_dp_size
:
int
,
new_world_size_across_dp
:
int
,
new_world_size_across_dp
:
int
,
master_ip
:
str
,
master_ip
:
str
,
world_group_ports
:
list
[
list
[
int
]],
coord_store_port
:
int
,
dp_group_ports
:
list
[
list
[
int
]],
enable_eplb
:
bool
=
True
,
ep_group_ports
:
list
[
list
[
int
]],
eplb_group_ports
:
list
[
list
[
int
]]
|
None
=
None
,
backend
:
str
|
None
=
None
,
backend
:
str
|
None
=
None
,
)
->
None
:
)
->
None
:
global
\
global
\
...
@@ -51,19 +49,23 @@ def create_standby_groups(
...
@@ -51,19 +49,23 @@ def create_standby_groups(
_STANDBY_EP
,
\
_STANDBY_EP
,
\
_STANDBY_EPLB
_STANDBY_EPLB
from
vllm.distributed.utils
import
get_cached_tcp_store_client
assert
new_world_size_across_dp
==
torch
.
distributed
.
get_world_size
()
*
new_dp_size
assert
new_world_size_across_dp
==
torch
.
distributed
.
get_world_size
()
*
new_dp_size
world_group
=
get_world_group
()
world_group
=
get_world_group
()
assert
isinstance
(
world_group
,
StatelessGroupCoordinator
)
assert
isinstance
(
world_group
,
StatelessGroupCoordinator
)
backend
=
backend
or
world_group
.
backend
backend
=
backend
or
world_group
.
backend
coord_store
=
get_cached_tcp_store_client
(
master_ip
,
coord_store_port
)
standby_world_ranks
=
[
list
(
range
(
new_world_size_across_dp
))]
standby_world_ranks
=
[
list
(
range
(
new_world_size_across_dp
))]
_STANDBY_WORLD
=
_init_stateless_group
(
_STANDBY_WORLD
=
_init_stateless_group
(
standby_world_ranks
,
standby_world_ranks
,
"world"
,
"world"
,
world_group_ports
,
master_ip
,
master_ip
,
backend
,
backend
,
use_device_communicator
=
False
,
use_device_communicator
=
False
,
coord_store
=
coord_store
,
)
)
_STANDBY_WORLD_NODE_COUNT
=
_node_count
(
_STANDBY_WORLD
.
tcp_store_group
)
_STANDBY_WORLD_NODE_COUNT
=
_node_count
(
_STANDBY_WORLD
.
tcp_store_group
)
...
@@ -76,7 +78,7 @@ def create_standby_groups(
...
@@ -76,7 +78,7 @@ def create_standby_groups(
standby_dp_ranks
=
all_ranks
.
transpose
(
1
,
3
).
reshape
(
-
1
,
new_dp_size
).
unbind
(
0
)
standby_dp_ranks
=
all_ranks
.
transpose
(
1
,
3
).
reshape
(
-
1
,
new_dp_size
).
unbind
(
0
)
standby_dp_ranks
=
[
x
.
tolist
()
for
x
in
standby_dp_ranks
]
standby_dp_ranks
=
[
x
.
tolist
()
for
x
in
standby_dp_ranks
]
_STANDBY_DP
=
_init_stateless_group
(
_STANDBY_DP
=
_init_stateless_group
(
standby_dp_ranks
,
"dp"
,
dp_group_ports
,
master_ip
,
backend
standby_dp_ranks
,
"dp"
,
master_ip
,
backend
,
coord_store
=
coord_store
)
)
standby_ep_ranks
=
(
standby_ep_ranks
=
(
...
@@ -84,12 +86,16 @@ def create_standby_groups(
...
@@ -84,12 +86,16 @@ def create_standby_groups(
)
)
standby_ep_ranks
=
[
x
.
tolist
()
for
x
in
standby_ep_ranks
]
standby_ep_ranks
=
[
x
.
tolist
()
for
x
in
standby_ep_ranks
]
_STANDBY_EP
=
_init_stateless_group
(
_STANDBY_EP
=
_init_stateless_group
(
standby_ep_ranks
,
"ep"
,
ep_group_ports
,
master_ip
,
backend
standby_ep_ranks
,
"ep"
,
master_ip
,
backend
,
coord_store
=
coord_store
)
)
if
e
plb_group_ports
is
not
None
:
if
e
nable_eplb
:
_STANDBY_EPLB
=
_init_stateless_group
(
_STANDBY_EPLB
=
_init_stateless_group
(
standby_ep_ranks
,
"eplb"
,
eplb_group_ports
,
master_ip
,
backend
standby_ep_ranks
,
"eplb"
,
master_ip
,
backend
,
coord_store
=
coord_store
,
)
)
...
...
vllm/distributed/parallel_state.py
View file @
de1a86b7
...
@@ -40,13 +40,16 @@ import torch
...
@@ -40,13 +40,16 @@ import torch
import
torch.distributed
import
torch.distributed
import
torch.distributed._functional_collectives
as
funcol
import
torch.distributed._functional_collectives
as
funcol
import
torch.distributed._symmetric_memory
import
torch.distributed._symmetric_memory
from
torch.distributed
import
Backend
,
ProcessGroup
from
torch.distributed
import
Backend
,
ProcessGroup
,
Store
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.device_communicators.base_device_communicator
import
(
from
vllm.distributed.device_communicators.base_device_communicator
import
(
DeviceCommunicatorBase
,
DeviceCommunicatorBase
,
)
)
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.distributed.utils
import
(
StatelessProcessGroup
,
get_cached_tcp_store_client
,
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.network_utils
import
get_distributed_init_method
from
vllm.utils.network_utils
import
get_distributed_init_method
...
@@ -1164,9 +1167,9 @@ def init_model_parallel_group(
...
@@ -1164,9 +1167,9 @@ def init_model_parallel_group(
def
_init_stateless_group
(
def
_init_stateless_group
(
group_ranks
:
list
[
list
[
int
]],
group_ranks
:
list
[
list
[
int
]],
group_name
:
str
,
group_name
:
str
,
group_ports
:
list
[
list
[
int
]],
host
:
str
,
host
:
str
,
backend
:
str
,
backend
:
str
,
coord_store
:
Store
,
use_device_communicator
:
bool
=
True
,
use_device_communicator
:
bool
=
True
,
)
->
"StatelessGroupCoordinator"
:
)
->
"StatelessGroupCoordinator"
:
"""Create a StatelessGroupCoordinator with the given parameters."""
"""Create a StatelessGroupCoordinator with the given parameters."""
...
@@ -1180,7 +1183,7 @@ def _init_stateless_group(
...
@@ -1180,7 +1183,7 @@ def _init_stateless_group(
use_device_communicator
=
use_device_communicator
,
use_device_communicator
=
use_device_communicator
,
group_name
=
group_name
,
group_name
=
group_name
,
host
=
host
,
host
=
host
,
group_ports
=
group_ports
,
coord_store
=
coord_store
,
global_rank
=
world
.
rank
,
global_rank
=
world
.
rank
,
global_world_size
=
world
.
world_size
,
global_world_size
=
world
.
world_size
,
)
)
...
@@ -1321,7 +1324,9 @@ def _init_elastic_ep_world(
...
@@ -1321,7 +1324,9 @@ def _init_elastic_ep_world(
group_ranks
=
[
all_ranks
[
i
:
i
+
1
]
for
i
in
range
(
global_world_size
)]
group_ranks
=
[
all_ranks
[
i
:
i
+
1
]
for
i
in
range
(
global_world_size
)]
if
global_rank
in
all_ranks
:
if
global_rank
in
all_ranks
:
group_ranks
=
[
all_ranks
]
group_ranks
=
[
all_ranks
]
group_ports
=
[
parallel_config
.
get_next_stateless_world_group_port
()]
coord_store
=
get_cached_tcp_store_client
(
parallel_config
.
data_parallel_master_ip
,
parallel_config
.
_coord_store_port
)
world
=
StatelessGroupCoordinator
(
world
=
StatelessGroupCoordinator
(
group_ranks
=
group_ranks
,
group_ranks
=
group_ranks
,
local_rank
=
local_rank
,
local_rank
=
local_rank
,
...
@@ -1329,7 +1334,7 @@ def _init_elastic_ep_world(
...
@@ -1329,7 +1334,7 @@ def _init_elastic_ep_world(
use_device_communicator
=
False
,
use_device_communicator
=
False
,
group_name
=
"world"
,
group_name
=
"world"
,
host
=
parallel_config
.
data_parallel_master_ip
,
host
=
parallel_config
.
data_parallel_master_ip
,
group_ports
=
group_ports
,
coord_store
=
coord_store
,
global_rank
=
global_rank
,
global_rank
=
global_rank
,
global_world_size
=
global_world_size
,
global_world_size
=
global_world_size
,
)
)
...
@@ -1513,7 +1518,13 @@ def initialize_model_parallel(
...
@@ -1513,7 +1518,13 @@ def initialize_model_parallel(
config
=
get_current_vllm_config
()
config
=
get_current_vllm_config
()
data_parallel_size
=
config
.
parallel_config
.
data_parallel_size
data_parallel_size
=
config
.
parallel_config
.
data_parallel_size
enable_elastic_ep
=
config
.
parallel_config
.
enable_elastic_ep
enable_elastic_ep
=
config
.
parallel_config
.
enable_elastic_ep
parallel_config
=
config
.
parallel_config
coord_store
:
Store
|
None
=
None
if
enable_elastic_ep
:
if
enable_elastic_ep
:
coord_store
=
get_cached_tcp_store_client
(
parallel_config
.
data_parallel_master_ip
,
parallel_config
.
_coord_store_port
,
)
# Use stateless world group for global information
# Use stateless world group for global information
world_size
=
get_world_group
().
world_size
world_size
=
get_world_group
().
world_size
rank
=
get_world_group
().
rank
rank
=
get_world_group
().
rank
...
@@ -1633,16 +1644,12 @@ def initialize_model_parallel(
...
@@ -1633,16 +1644,12 @@ def initialize_model_parallel(
group_ranks
=
all_ranks
.
transpose
(
1
,
4
).
reshape
(
-
1
,
data_parallel_size
).
unbind
(
0
)
group_ranks
=
all_ranks
.
transpose
(
1
,
4
).
reshape
(
-
1
,
data_parallel_size
).
unbind
(
0
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
if
enable_elastic_ep
:
if
enable_elastic_ep
:
parallel_config
=
config
.
parallel_config
dp_ports
=
[
parallel_config
.
get_next_stateless_dp_group_port
()
for
_
in
group_ranks
]
_DP
=
_init_stateless_group
(
_DP
=
_init_stateless_group
(
group_ranks
,
group_ranks
,
"dp"
,
"dp"
,
dp_ports
,
parallel_config
.
data_parallel_master_ip
,
parallel_config
.
data_parallel_master_ip
,
backend
,
backend
,
coord_store
=
coord_store
,
)
)
else
:
else
:
_DP
=
init_model_parallel_group
(
_DP
=
init_model_parallel_group
(
...
@@ -1665,16 +1672,12 @@ def initialize_model_parallel(
...
@@ -1665,16 +1672,12 @@ def initialize_model_parallel(
)
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
if
enable_elastic_ep
:
if
enable_elastic_ep
:
parallel_config
=
config
.
parallel_config
ep_ports
=
[
parallel_config
.
get_next_stateless_ep_group_port
()
for
_
in
group_ranks
]
_EP
=
_init_stateless_group
(
_EP
=
_init_stateless_group
(
group_ranks
,
group_ranks
,
"ep"
,
"ep"
,
ep_ports
,
parallel_config
.
data_parallel_master_ip
,
parallel_config
.
data_parallel_master_ip
,
backend
,
backend
,
coord_store
=
coord_store
,
)
)
else
:
else
:
_EP
=
init_model_parallel_group
(
_EP
=
init_model_parallel_group
(
...
@@ -1693,16 +1696,12 @@ def initialize_model_parallel(
...
@@ -1693,16 +1696,12 @@ def initialize_model_parallel(
and
config
.
parallel_config
.
enable_eplb
and
config
.
parallel_config
.
enable_eplb
):
):
if
enable_elastic_ep
:
if
enable_elastic_ep
:
eplb_ports
=
[
parallel_config
.
get_next_stateless_eplb_group_port
()
for
_
in
group_ranks
]
_EPLB
=
_init_stateless_group
(
_EPLB
=
_init_stateless_group
(
group_ranks
,
group_ranks
,
"eplb"
,
"eplb"
,
eplb_ports
,
parallel_config
.
data_parallel_master_ip
,
parallel_config
.
data_parallel_master_ip
,
backend
,
backend
,
coord_store
=
coord_store
,
)
)
else
:
else
:
_EPLB
=
init_model_parallel_group
(
_EPLB
=
init_model_parallel_group
(
...
...
vllm/distributed/stateless_coordinator.py
View file @
de1a86b7
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
socket
import
struct
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
torch
import
torch
from
torch.distributed
import
Backend
,
ProcessGroup
from
torch.distributed
import
Backend
,
ProcessGroup
,
Store
from
vllm.distributed.device_communicators.cuda_communicator
import
CudaCommunicator
from
vllm.distributed.device_communicators.cuda_communicator
import
CudaCommunicator
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
...
@@ -23,6 +25,38 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
...
@@ -23,6 +25,38 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_PORTS_FMT
=
"!3I"
def
_allocate_group_ports
(
key
:
str
,
host
:
str
,
coord_store
:
Store
,
)
->
tuple
[
list
[
int
],
list
[
socket
.
socket
]]:
"""Bind 3 sockets and publish the ports to *coord_store*.
Called by rank 0 only. Returns ``(ports, sockets)`` with the
sockets still open.
"""
socks
:
list
[
socket
.
socket
]
=
[]
ports
:
list
[
int
]
=
[]
for
_
in
range
(
3
):
s
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
s
.
bind
((
host
,
0
))
s
.
listen
()
socks
.
append
(
s
)
ports
.
append
(
s
.
getsockname
()[
1
])
coord_store
.
set
(
key
,
struct
.
pack
(
_PORTS_FMT
,
*
ports
))
return
ports
,
socks
def
_fetch_group_ports
(
key
:
str
,
coord_store
:
Store
)
->
list
[
int
]:
"""Read 3 ports published by rank 0 from *coord_store*.
Blocks until the key is available.
"""
return
list
(
struct
.
unpack
(
_PORTS_FMT
,
coord_store
.
get
(
key
)))
class
StatelessGroupCoordinator
(
GroupCoordinator
):
class
StatelessGroupCoordinator
(
GroupCoordinator
):
"""
"""
...
@@ -39,10 +73,10 @@ class StatelessGroupCoordinator(GroupCoordinator):
...
@@ -39,10 +73,10 @@ class StatelessGroupCoordinator(GroupCoordinator):
local_rank
:
int
,
local_rank
:
int
,
torch_distributed_backend
:
str
|
Backend
,
torch_distributed_backend
:
str
|
Backend
,
use_device_communicator
:
bool
,
use_device_communicator
:
bool
,
coord_store
:
Store
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
str
|
None
=
None
,
group_name
:
str
|
None
=
None
,
host
:
str
=
"127.0.0.1"
,
host
:
str
=
"127.0.0.1"
,
group_ports
:
list
[
list
[
int
]]
|
None
=
None
,
global_rank
:
int
=
0
,
global_rank
:
int
=
0
,
global_world_size
:
int
=
1
,
global_world_size
:
int
=
1
,
):
):
...
@@ -61,17 +95,23 @@ class StatelessGroupCoordinator(GroupCoordinator):
...
@@ -61,17 +95,23 @@ class StatelessGroupCoordinator(GroupCoordinator):
backend
=
str
(
torch_distributed_backend
)
backend
=
str
(
torch_distributed_backend
)
self
.
backend
=
backend
self
.
backend
=
backend
assert
group_ports
is
not
None
,
"group_ports is not provided"
for
idx
,
ranks
in
enumerate
(
group_ranks
):
for
idx
,
ranks
in
enumerate
(
group_ranks
):
if
self
.
rank
in
ranks
:
if
self
.
rank
in
ranks
:
self
.
ranks
=
ranks
self
.
ranks
=
ranks
self
.
world_size
=
len
(
ranks
)
self
.
world_size
=
len
(
ranks
)
self
.
rank_in_group
=
ranks
.
index
(
self
.
rank
)
self
.
rank_in_group
=
ranks
.
index
(
self
.
rank
)
ports
=
group_ports
[
idx
]
key
=
f
"
{
group_name
}
_
{
idx
}
"
device_port
=
ports
[
0
]
if
self
.
rank_in_group
==
0
:
cpu_port
=
ports
[
1
]
ports
,
socks
=
_allocate_group_ports
(
tcp_store_port
=
ports
[
2
]
key
,
host
,
coord_store
,
)
else
:
ports
=
_fetch_group_ports
(
key
,
coord_store
)
socks
=
[]
device_port
,
cpu_port
,
tcp_store_port
=
ports
device_group
=
stateless_init_torch_distributed_process_group
(
device_group
=
stateless_init_torch_distributed_process_group
(
host
=
host
,
host
=
host
,
...
@@ -80,6 +120,7 @@ class StatelessGroupCoordinator(GroupCoordinator):
...
@@ -80,6 +120,7 @@ class StatelessGroupCoordinator(GroupCoordinator):
world_size
=
self
.
world_size
,
world_size
=
self
.
world_size
,
backend
=
backend
,
backend
=
backend
,
group_name
=
f
"
{
self
.
unique_name
}
_device"
,
group_name
=
f
"
{
self
.
unique_name
}
_device"
,
listen_socket
=
socks
[
0
]
if
socks
else
None
,
)
)
cpu_group
=
stateless_init_torch_distributed_process_group
(
cpu_group
=
stateless_init_torch_distributed_process_group
(
host
=
host
,
host
=
host
,
...
@@ -88,12 +129,14 @@ class StatelessGroupCoordinator(GroupCoordinator):
...
@@ -88,12 +129,14 @@ class StatelessGroupCoordinator(GroupCoordinator):
world_size
=
self
.
world_size
,
world_size
=
self
.
world_size
,
backend
=
"gloo"
,
backend
=
"gloo"
,
group_name
=
f
"
{
self
.
unique_name
}
_cpu"
,
group_name
=
f
"
{
self
.
unique_name
}
_cpu"
,
listen_socket
=
socks
[
1
]
if
socks
else
None
,
)
)
tcp_store_group
=
StatelessProcessGroup
.
create
(
tcp_store_group
=
StatelessProcessGroup
.
create
(
host
=
host
,
host
=
host
,
port
=
tcp_store_port
,
port
=
tcp_store_port
,
rank
=
self
.
rank_in_group
,
rank
=
self
.
rank_in_group
,
world_size
=
self
.
world_size
,
world_size
=
self
.
world_size
,
listen_socket
=
socks
[
2
]
if
socks
else
None
,
)
)
self_device_group
=
device_group
self_device_group
=
device_group
...
...
vllm/distributed/utils.py
View file @
de1a86b7
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import
dataclasses
import
dataclasses
import
functools
import
os
import
os
import
pickle
import
pickle
import
socket
import
socket
...
@@ -139,6 +140,29 @@ def get_pp_indices(
...
@@ -139,6 +140,29 @@ def get_pp_indices(
return
(
start_layer
,
end_layer
)
return
(
start_layer
,
end_layer
)
def
create_tcp_store
(
host
:
str
,
port
:
int
,
listen_socket
:
socket
.
socket
|
None
=
None
,
**
kwargs
:
Any
,
)
->
TCPStore
:
"""Create a TCPStore, optionally taking ownership of ``listen_socket``."""
if
listen_socket
is
None
:
return
TCPStore
(
host_name
=
host
,
port
=
port
,
**
kwargs
)
listen_fd
=
listen_socket
.
detach
()
try
:
return
TCPStore
(
host_name
=
host
,
port
=
port
,
master_listen_fd
=
listen_fd
,
**
kwargs
,
)
except
Exception
:
socket
.
close
(
listen_fd
)
raise
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
StatelessProcessGroup
:
class
StatelessProcessGroup
:
"""A dataclass to hold a metadata store, and the rank, world_size of the
"""A dataclass to hold a metadata store, and the rank, world_size of the
...
@@ -150,9 +174,6 @@ class StatelessProcessGroup:
...
@@ -150,9 +174,6 @@ class StatelessProcessGroup:
world_size
:
int
world_size
:
int
store
:
torch
.
_C
.
_distributed_c10d
.
Store
store
:
torch
.
_C
.
_distributed_c10d
.
Store
# stores a reference to the socket so that the file descriptor stays alive
socket
:
socket
.
socket
|
None
data_expiration_seconds
:
int
=
3600
# 1 hour
data_expiration_seconds
:
int
=
3600
# 1 hour
# dst rank -> counter
# dst rank -> counter
...
@@ -419,6 +440,7 @@ class StatelessProcessGroup:
...
@@ -419,6 +440,7 @@ class StatelessProcessGroup:
world_size
:
int
,
world_size
:
int
,
data_expiration_seconds
:
int
=
3600
,
data_expiration_seconds
:
int
=
3600
,
store_timeout
:
int
=
300
,
store_timeout
:
int
=
300
,
listen_socket
:
socket
.
socket
|
None
=
None
,
)
->
"StatelessProcessGroup"
:
)
->
"StatelessProcessGroup"
:
"""A replacement for `torch.distributed.init_process_group` that does not
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
pollute the global state.
...
@@ -436,36 +458,39 @@ class StatelessProcessGroup:
...
@@ -436,36 +458,39 @@ class StatelessProcessGroup:
C, and D can call `StatelessProcessGroup.create` to form another group.
C, and D can call `StatelessProcessGroup.create` to form another group.
"""
# noqa
"""
# noqa
launch_server
=
rank
==
0
launch_server
=
rank
==
0
if
launch_server
:
if
launch_server
and
listen_socket
is
None
:
# listen on the specified interface (instead of 0.0.0.0)
listen_socket
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
listen_socket
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
listen_socket
.
setsockopt
(
socket
.
SOL_SOCKET
,
socket
.
SO_REUSEADDR
,
1
)
listen_socket
.
setsockopt
(
socket
.
SOL_SOCKET
,
socket
.
SO_REUSEADDR
,
1
)
listen_socket
.
bind
((
host
,
port
))
listen_socket
.
bind
((
host
,
port
))
listen_socket
.
listen
()
listen_socket
.
listen
()
listen_fd
=
listen_socket
.
fileno
()
store
=
create_tcp_store
(
else
:
host
,
listen_socket
=
None
port
,
listen_fd
=
None
listen_socket
=
listen_socket
,
store
=
TCPStore
(
host_name
=
host
,
port
=
port
,
world_size
=
world_size
,
world_size
=
world_size
,
is_master
=
launch_server
,
is_master
=
launch_server
,
timeout
=
timedelta
(
seconds
=
store_timeout
),
timeout
=
timedelta
(
seconds
=
store_timeout
),
use_libuv
=
False
,
# for now: github.com/pytorch/pytorch/pull/150215
use_libuv
=
False
,
# for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd
=
listen_fd
,
)
)
return
StatelessProcessGroup
(
return
StatelessProcessGroup
(
rank
=
rank
,
rank
=
rank
,
world_size
=
world_size
,
world_size
=
world_size
,
store
=
store
,
store
=
store
,
socket
=
listen_socket
,
data_expiration_seconds
=
data_expiration_seconds
,
data_expiration_seconds
=
data_expiration_seconds
,
)
)
@
functools
.
lru_cache
(
maxsize
=
1
)
def
get_cached_tcp_store_client
(
host
:
str
,
port
:
int
)
->
TCPStore
:
"""Return a cached TCPStore client.
Cached so that every call with the same ``(host, port)`` reuses the
same connection. A new ``(host, port)`` evicts the old entry.
"""
return
TCPStore
(
host
,
port
,
is_master
=
False
,
wait_for_workers
=
False
)
def
init_gloo_process_group
(
def
init_gloo_process_group
(
prefix_store
:
PrefixStore
,
prefix_store
:
PrefixStore
,
group_rank
:
int
,
group_rank
:
int
,
...
@@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group(
...
@@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group(
backend
:
str
,
backend
:
str
,
group_name
:
str
|
None
=
None
,
group_name
:
str
|
None
=
None
,
return_store
:
bool
=
False
,
return_store
:
bool
=
False
,
listen_socket
:
socket
.
socket
|
None
=
None
,
)
->
ProcessGroup
|
tuple
[
ProcessGroup
,
Store
]:
)
->
ProcessGroup
|
tuple
[
ProcessGroup
,
Store
]:
"""
"""
A replacement for `torch.distributed.init_process_group` that does not
A replacement for `torch.distributed.init_process_group` that does not
...
@@ -535,14 +561,30 @@ def stateless_init_torch_distributed_process_group(
...
@@ -535,14 +561,30 @@ def stateless_init_torch_distributed_process_group(
are the same as process 1 and 5, the main communication channel is
are the same as process 1 and 5, the main communication channel is
always formed with process 1, 2, ..., 8, and the additional communication
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
channel is formed with process 9 and 10.
When *listen_socket* is provided, the rendezvous step
is skipped and a ``TCPStore`` server is created directly using the
pre-bound socket. This is useful for eliminating TOCTOU races
between port allocation and binding.
"""
"""
init_method
=
get_tcp_uri
(
host
,
port
)
init_method
=
get_tcp_uri
(
host
,
port
)
backend
=
Backend
(
backend
)
# it is basically string
backend
=
Backend
(
backend
)
# it is basically string
timeout
=
_get_default_timeout
(
backend
)
timeout
=
_get_default_timeout
(
backend
)
store
,
rank
,
world_size
=
next
(
if
listen_socket
is
not
None
:
rendezvous
(
init_method
,
rank
,
world_size
,
timeout
=
timeout
)
store
=
create_tcp_store
(
)
host
,
port
,
listen_socket
=
listen_socket
,
world_size
=
world_size
,
is_master
=
True
,
timeout
=
timeout
,
multi_tenant
=
True
,
)
else
:
store
,
rank
,
world_size
=
next
(
rendezvous
(
init_method
,
rank
,
world_size
,
timeout
=
timeout
)
)
store
.
set_timeout
(
timeout
)
store
.
set_timeout
(
timeout
)
group_rank
=
rank
group_rank
=
rank
...
...
vllm/v1/engine/__init__.py
View file @
de1a86b7
...
@@ -237,10 +237,7 @@ class ReconfigureDistributedRequest(msgspec.Struct):
...
@@ -237,10 +237,7 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_master_ip
:
str
new_data_parallel_master_ip
:
str
new_data_parallel_master_port
:
int
new_data_parallel_master_port
:
int
new_data_parallel_master_port_list
:
list
[
int
]
new_data_parallel_master_port_list
:
list
[
int
]
new_stateless_world_group_port_list
:
list
[
list
[
int
]]
coord_store_port
:
int
new_stateless_dp_group_port_list
:
list
[
list
[
int
]]
new_stateless_ep_group_port_list
:
list
[
list
[
int
]]
new_stateless_eplb_group_port_list
:
list
[
list
[
int
]]
class
ReconfigureRankType
(
enum
.
IntEnum
):
class
ReconfigureRankType
(
enum
.
IntEnum
):
...
...
vllm/v1/engine/core.py
View file @
de1a86b7
...
@@ -1767,6 +1767,7 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -1767,6 +1767,7 @@ class DPEngineCoreProc(EngineCoreProc):
new_parallel_config
.
_data_parallel_master_port_list
=
(
new_parallel_config
.
_data_parallel_master_port_list
=
(
reconfig_request
.
new_data_parallel_master_port_list
reconfig_request
.
new_data_parallel_master_port_list
)
)
new_parallel_config
.
_coord_store_port
=
reconfig_request
.
coord_store_port
is_scale_down
=
reconfig_request
.
new_data_parallel_size
<
old_dp_size
is_scale_down
=
reconfig_request
.
new_data_parallel_size
<
old_dp_size
is_shutdown
=
(
is_shutdown
=
(
...
...
vllm/v1/engine/core_client.py
View file @
de1a86b7
...
@@ -455,56 +455,6 @@ class ElasticScalingCache:
...
@@ -455,56 +455,6 @@ class ElasticScalingCache:
pending_notifications
:
dict
[
EEPNotificationType
,
set
[
int
]]
pending_notifications
:
dict
[
EEPNotificationType
,
set
[
int
]]
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
    """
    Allocate stateless group ports for elastic EP.

    Requests one batch of open ports from ``get_open_ports_list`` and splits
    it up: the last five ports become the data-parallel master ports, and
    the rest is cut into consecutive three-port slices that are handed out,
    in order, to the world, DP, EP and EPLB groups.

    The results are written onto ``parallel_config``:
    ``_stateless_world_group_port_list``, ``_stateless_dp_group_port_list``,
    ``_stateless_ep_group_port_list``, ``_stateless_eplb_group_port_list``,
    ``data_parallel_master_port`` (the last of the five master ports) and
    ``_data_parallel_master_port_list`` (the remaining four).

    Args:
        parallel_config: Config object; must have ``enable_elastic_ep`` set
            and provide ``world_size`` and ``tensor_parallel_size``.
        new_data_parallel_size: Data-parallel size the ports are sized for.

    Raises:
        AssertionError: If elastic EP is not enabled on ``parallel_config``.
    """
    from vllm.utils.network_utils import get_open_ports_list

    assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"

    ports_per_group = 3
    master_port_count = 5

    total_ranks = parallel_config.world_size * new_data_parallel_size
    dp_group_count = max(1, total_ranks // new_data_parallel_size)
    ep_group_count = max(
        1,
        total_ranks
        // (new_data_parallel_size * parallel_config.tensor_parallel_size),
    )
    # Order matters: world, DP, EP, EPLB. EPLB uses as many groups as EP.
    group_counts = (1, dp_group_count, ep_group_count, ep_group_count)

    port_pool = get_open_ports_list(
        sum(group_counts) * ports_per_group + master_port_count
    )
    master_ports = port_pool[-master_port_count:]
    group_port_pool = port_pool[:-master_port_count]

    # Walk the pool once, carving out `count` three-port slices per group kind.
    per_kind_ports = []
    cursor = 0
    for count in group_counts:
        stop = cursor + count * ports_per_group
        per_kind_ports.append(
            [
                group_port_pool[begin : begin + ports_per_group]
                for begin in range(cursor, stop, ports_per_group)
            ]
        )
        cursor = stop
    world_ports, dp_ports, ep_ports, eplb_ports = per_kind_ports

    parallel_config._stateless_world_group_port_list = world_ports
    parallel_config._stateless_dp_group_port_list = dp_ports
    parallel_config._stateless_ep_group_port_list = ep_ports
    parallel_config._stateless_eplb_group_port_list = eplb_ports
    # The last master port is the active one; the other four stay in the list.
    parallel_config.data_parallel_master_port = master_ports.pop()
    parallel_config._data_parallel_master_port_list = master_ports
class
MPClient
(
EngineCoreClient
):
class
MPClient
(
EngineCoreClient
):
"""
"""
MPClient: base client for multi-proc EngineCore.
MPClient: base client for multi-proc EngineCore.
...
@@ -1541,6 +1491,28 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
...
@@ -1541,6 +1491,28 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
self
.
_ensure_output_queue_task
()
self
.
_ensure_output_queue_task
()
await
future
await
future
def _setup_elastic_ep_reconfig_bootstrap(self) -> tuple[str, int]:
    """Prepare fresh ports and a coordination store for a reconfiguration.

    Steps, all applied to ``self.vllm_config.parallel_config``:

    1. Request five open ports; the last one becomes
       ``data_parallel_master_port`` and the remaining four are kept in
       ``_data_parallel_master_port_list``.
    2. Create a master TCP store on ``data_parallel_master_ip`` with port
       argument ``0`` and record the resulting ``store.port`` in
       ``_coord_store_port``.
       NOTE(review): port ``0`` presumably lets the store bind a free port
       itself rather than one picked in advance; confirm against
       ``create_tcp_store``.

    The store object is also kept on ``self._coord_store``.

    Returns:
        ``(ip, port)`` of the newly created coordination store.
    """
    from vllm.distributed.utils import create_tcp_store
    from vllm.utils.network_utils import get_open_ports_list

    config = self.vllm_config.parallel_config

    fresh_ports = get_open_ports_list(5)
    master_port = fresh_ports.pop()
    config._data_parallel_master_port_list = fresh_ports
    config.data_parallel_master_port = master_port

    master_ip = config.data_parallel_master_ip
    coord_store = create_tcp_store(
        master_ip,
        0,
        is_master=True,
        world_size=-1,
        wait_for_workers=False,
    )
    coord_port = coord_store.port
    config._coord_store_port = coord_port
    self._coord_store = coord_store
    return master_ip, coord_port
async
def
_scale_up_elastic_ep
(
async
def
_scale_up_elastic_ep
(
self
,
cur_data_parallel_size
:
int
,
new_data_parallel_size
:
int
self
,
cur_data_parallel_size
:
int
,
new_data_parallel_size
:
int
)
->
None
:
)
->
None
:
...
@@ -1555,7 +1527,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
...
@@ -1555,7 +1527,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
)
parallel_config
=
self
.
vllm_config
.
parallel_config
parallel_config
=
self
.
vllm_config
.
parallel_config
allocate_stateless_group_ports
(
parallel_config
,
new_data_parallel_size
)
ip
,
coord_store_port
=
self
.
_setup_elastic_ep_reconfig_bootstrap
(
)
# Phase 1: Send reconfig messages to existing engines
# Phase 1: Send reconfig messages to existing engines
reconfig_futures
=
[]
reconfig_futures
=
[]
...
@@ -1564,13 +1536,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
...
@@ -1564,13 +1536,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size
=
new_data_parallel_size
,
new_data_parallel_size
=
new_data_parallel_size
,
new_data_parallel_rank
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_rank
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_rank_local
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_rank_local
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_master_ip
=
parallel_config
.
data_parallel_master_
ip
,
new_data_parallel_master_ip
=
ip
,
new_data_parallel_master_port
=
parallel_config
.
data_parallel_master_port
,
new_data_parallel_master_port
=
parallel_config
.
data_parallel_master_port
,
new_data_parallel_master_port_list
=
parallel_config
.
_data_parallel_master_port_list
,
new_data_parallel_master_port_list
=
parallel_config
.
_data_parallel_master_port_list
,
new_stateless_world_group_port_list
=
parallel_config
.
_stateless_world_group_port_list
,
coord_store_port
=
coord_store_port
,
new_stateless_dp_group_port_list
=
parallel_config
.
_stateless_dp_group_port_list
,
new_stateless_ep_group_port_list
=
parallel_config
.
_stateless_ep_group_port_list
,
new_stateless_eplb_group_port_list
=
parallel_config
.
_stateless_eplb_group_port_list
,
)
)
coro
=
self
.
_call_utility_async
(
coro
=
self
.
_call_utility_async
(
"reinitialize_distributed"
,
reconfig_request
,
engine
=
engine
"reinitialize_distributed"
,
reconfig_request
,
engine
=
engine
...
@@ -1650,7 +1619,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
...
@@ -1650,7 +1619,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
)
parallel_config
=
self
.
vllm_config
.
parallel_config
parallel_config
=
self
.
vllm_config
.
parallel_config
allocate_stateless_group_ports
(
parallel_config
,
new_data_parallel_size
)
ip
,
coord_store_port
=
self
.
_setup_elastic_ep_reconfig_bootstrap
(
)
reconfig_futures
=
[]
reconfig_futures
=
[]
for
cur_dp_rank
,
engine
in
enumerate
(
self
.
core_engines
):
for
cur_dp_rank
,
engine
in
enumerate
(
self
.
core_engines
):
...
@@ -1658,13 +1627,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
...
@@ -1658,13 +1627,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size
=
new_data_parallel_size
,
new_data_parallel_size
=
new_data_parallel_size
,
new_data_parallel_rank
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_rank
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_rank_local
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_rank_local
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_master_ip
=
parallel_config
.
data_parallel_master_
ip
,
new_data_parallel_master_ip
=
ip
,
new_data_parallel_master_port
=
parallel_config
.
data_parallel_master_port
,
new_data_parallel_master_port
=
parallel_config
.
data_parallel_master_port
,
new_data_parallel_master_port_list
=
parallel_config
.
_data_parallel_master_port_list
,
new_data_parallel_master_port_list
=
parallel_config
.
_data_parallel_master_port_list
,
new_stateless_world_group_port_list
=
parallel_config
.
_stateless_world_group_port_list
,
coord_store_port
=
coord_store_port
,
new_stateless_dp_group_port_list
=
parallel_config
.
_stateless_dp_group_port_list
,
new_stateless_ep_group_port_list
=
parallel_config
.
_stateless_ep_group_port_list
,
new_stateless_eplb_group_port_list
=
parallel_config
.
_stateless_eplb_group_port_list
,
)
)
if
cur_dp_rank
>=
new_data_parallel_size
:
if
cur_dp_rank
>=
new_data_parallel_size
:
reconfig_request
.
new_data_parallel_rank
=
(
reconfig_request
.
new_data_parallel_rank
=
(
...
...
vllm/v1/engine/utils.py
View file @
de1a86b7
...
@@ -301,7 +301,20 @@ class CoreEngineActorManager:
...
@@ -301,7 +301,20 @@ class CoreEngineActorManager:
else
:
else
:
ray
.
init
()
ray
.
init
()
vllm_config
.
parallel_config
.
allocate_elastic_ep_ports
()
parallel_config
=
vllm_config
.
parallel_config
if
parallel_config
.
enable_elastic_ep
:
from
vllm.distributed.utils
import
create_tcp_store
ip
=
parallel_config
.
data_parallel_master_ip
store
=
create_tcp_store
(
ip
,
0
,
is_master
=
True
,
world_size
=-
1
,
wait_for_workers
=
False
,
)
parallel_config
.
_coord_store_port
=
store
.
port
self
.
_coord_store
=
store
if
placement_groups
is
not
None
:
if
placement_groups
is
not
None
:
assert
local_dp_ranks
is
not
None
,
(
assert
local_dp_ranks
is
not
None
,
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment