Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de1a86b7
Unverified
Commit
de1a86b7
authored
Mar 18, 2026
by
Itay Alroy
Committed by
GitHub
Mar 18, 2026
Browse files
elastic_ep: Fix stateless group port races (#36330)
Signed-off-by:
Itay Alroy
<
ialroy@nvidia.com
>
parent
99267c23
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
224 additions
and
225 deletions
+224
-225
.buildkite/test_areas/expert_parallelism.yaml
.buildkite/test_areas/expert_parallelism.yaml
+1
-2
vllm/config/parallel.py
vllm/config/parallel.py
+31
-85
vllm/distributed/elastic_ep/elastic_execute.py
vllm/distributed/elastic_ep/elastic_execute.py
+2
-4
vllm/distributed/elastic_ep/elastic_state.py
vllm/distributed/elastic_ep/elastic_state.py
+1
-12
vllm/distributed/elastic_ep/standby_state.py
vllm/distributed/elastic_ep/standby_state.py
+15
-9
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+20
-21
vllm/distributed/stateless_coordinator.py
vllm/distributed/stateless_coordinator.py
+50
-7
vllm/distributed/utils.py
vllm/distributed/utils.py
+60
-18
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+1
-4
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+1
-0
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+28
-62
vllm/v1/engine/utils.py
vllm/v1/engine/utils.py
+14
-1
No files found.
.buildkite/test_areas/expert_parallelism.yaml
View file @
de1a86b7
...
...
@@ -24,8 +24,7 @@ steps:
-
label
:
Elastic EP Scaling Test
timeout_in_minutes
:
20
device
:
b200
optional
:
true
device
:
h100
working_dir
:
"
/vllm-workspace/tests"
num_devices
:
4
source_file_dependencies
:
...
...
vllm/config/parallel.py
View file @
de1a86b7
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
socket
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
overload
...
...
@@ -266,33 +267,9 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
"""
_stateless_dp_group_port_list
:
list
[
list
[
int
]]
=
Field
(
default_factory
=
list
)
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
It is a list of list[int], with each inner list contains a set of 3 ports
to be used for setting up the stateless CPU/device/TCPStore groups
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
"""
_stateless_ep_group_port_list
:
list
[
list
[
int
]]
=
Field
(
default_factory
=
list
)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
"""
_stateless_eplb_group_port_list
:
list
[
list
[
int
]]
=
Field
(
default_factory
=
list
)
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
Same topology as EP but separate NCCL communicator to avoid deadlocks.
"""
_stateless_world_group_port_list
:
list
[
list
[
int
]]
=
Field
(
default_factory
=
list
)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_world_group_port_list) == 1,
"""
_coord_store_port
:
int
=
0
"""Port of the coordination TCPStore. Can be set by the API server; workers
connect as clients to exchange self-picked group ports at runtime."""
decode_context_parallel_size
:
int
=
1
"""Number of decode context parallel groups, because the world size does
...
...
@@ -465,65 +442,32 @@ class ParallelConfig:
return
answer
def
allocate_elastic_e
p_port
s
(
self
)
->
None
:
"""
Allocate all ports for elastic EP (stateless groups + DP master)
.
def
_pick_stateless_d
p_port
(
self
)
->
tuple
[
int
,
socket
.
socket
|
None
]
:
"""
Return ``(port, listen_socket)`` for DP group init
.
Must be called AFTER ray.init() so that ports claimed by Ray's
idle worker pool are al
read
y
i
n use and won't be returned by
g
et
_open_ports_list()
.
With a coord store, rank 0 binds a socket and publishes the port;
others
read i
t. Without one, pops a pre-allocated port and
r
et
urns ``listen_socket=None``
.
"""
if
not
self
.
enable_elastic_ep
:
return
if
self
.
_stateless_world_group_port_list
:
return
num_world_groups
=
1
dp_size
=
self
.
data_parallel_size
ep_size
=
self
.
data_parallel_size
*
self
.
world_size_across_dp
num_dp_groups
=
max
(
1
,
self
.
world_size_across_dp
//
dp_size
)
num_ep_groups
=
max
(
1
,
self
.
world_size_across_dp
//
ep_size
)
num_eplb_groups
=
num_ep_groups
total_stateless_ports
=
(
num_world_groups
+
num_dp_groups
+
num_ep_groups
+
num_eplb_groups
)
*
3
num_dp_master_ports
=
5
all_ports
=
get_open_ports_list
(
total_stateless_ports
+
num_dp_master_ports
)
self
.
_data_parallel_master_port_list
=
all_ports
[
-
num_dp_master_ports
:]
self
.
data_parallel_master_port
=
self
.
_data_parallel_master_port_list
.
pop
()
all_ports
=
all_ports
[:
-
num_dp_master_ports
]
self
.
_stateless_world_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
0
,
num_world_groups
*
3
,
3
)
]
start_idx
=
num_world_groups
*
3
self
.
_stateless_dp_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_dp_groups
*
3
,
3
)
]
start_idx
+=
num_dp_groups
*
3
self
.
_stateless_ep_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_ep_groups
*
3
,
3
)
]
start_idx
+=
num_ep_groups
*
3
self
.
_stateless_eplb_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_eplb_groups
*
3
,
3
)
]
def
get_next_stateless_world_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_world_group_port_list
.
pop
()
def
get_next_stateless_dp_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_dp_group_port_list
.
pop
()
def
get_next_stateless_ep_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_ep_group_port_list
.
pop
()
def
get_next_stateless_eplb_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_eplb_group_port_list
.
pop
()
if
not
self
.
_coord_store_port
:
return
self
.
get_next_dp_init_port
(),
None
from
vllm.distributed.utils
import
get_cached_tcp_store_client
store
=
get_cached_tcp_store_client
(
self
.
data_parallel_master_ip
,
self
.
_coord_store_port
)
key
=
"dp_master_port"
if
self
.
data_parallel_rank
==
0
:
s
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
s
.
bind
((
self
.
data_parallel_master_ip
,
0
))
s
.
listen
()
port
=
s
.
getsockname
()[
1
]
store
.
set
(
key
,
str
(
port
).
encode
())
return
port
,
s
else
:
return
int
(
store
.
get
(
key
).
decode
()),
None
@
overload
def
stateless_init_dp_group
(
...
...
@@ -553,14 +497,16 @@ class ParallelConfig:
last_exc
:
Exception
|
None
=
None
for
_
in
range
(
max_retries
):
try
:
port
,
listen_socket
=
self
.
_pick_stateless_dp_port
()
# use gloo since the engine process might not have cuda device
return
stateless_init_torch_distributed_process_group
(
self
.
data_parallel_master_ip
,
self
.
get_next_dp_init_
port
()
,
port
,
self
.
data_parallel_rank
,
self
.
data_parallel_size
,
backend
=
"gloo"
,
return_store
=
return_store
,
listen_socket
=
listen_socket
,
)
except
DistNetworkError
as
e
:
# We only want to retry when the root cause is EADDRINUSE.
...
...
vllm/distributed/elastic_ep/elastic_execute.py
View file @
de1a86b7
...
...
@@ -162,10 +162,8 @@ class ElasticEPScalingExecutor:
new_dp_size
=
new_dp_size
,
new_world_size_across_dp
=
new_world_size_across_dp
,
master_ip
=
reconfig_request
.
new_data_parallel_master_ip
,
world_group_ports
=
reconfig_request
.
new_stateless_world_group_port_list
,
dp_group_ports
=
reconfig_request
.
new_stateless_dp_group_port_list
,
ep_group_ports
=
reconfig_request
.
new_stateless_ep_group_port_list
,
eplb_group_ports
=
reconfig_request
.
new_stateless_eplb_group_port_list
,
coord_store_port
=
reconfig_request
.
coord_store_port
,
enable_eplb
=
updated_config
.
parallel_config
.
enable_eplb
,
)
self
.
worker
.
model_runner
.
eep_eplb_suppressed
=
True
standby_ep_group
=
get_standby_ep_group
()
...
...
vllm/distributed/elastic_ep/elastic_state.py
View file @
de1a86b7
...
...
@@ -563,15 +563,4 @@ class ElasticEPScalingState:
parallel_config
.
_data_parallel_master_port_list
=
(
reconfig_request
.
new_data_parallel_master_port_list
)
parallel_config
.
_stateless_world_group_port_list
=
(
reconfig_request
.
new_stateless_world_group_port_list
)
parallel_config
.
_stateless_dp_group_port_list
=
(
reconfig_request
.
new_stateless_dp_group_port_list
)
parallel_config
.
_stateless_ep_group_port_list
=
(
reconfig_request
.
new_stateless_ep_group_port_list
)
parallel_config
.
_stateless_eplb_group_port_list
=
(
reconfig_request
.
new_stateless_eplb_group_port_list
)
parallel_config
.
_coord_store_port
=
reconfig_request
.
coord_store_port
vllm/distributed/elastic_ep/standby_state.py
View file @
de1a86b7
...
...
@@ -38,10 +38,8 @@ def create_standby_groups(
new_dp_size
:
int
,
new_world_size_across_dp
:
int
,
master_ip
:
str
,
world_group_ports
:
list
[
list
[
int
]],
dp_group_ports
:
list
[
list
[
int
]],
ep_group_ports
:
list
[
list
[
int
]],
eplb_group_ports
:
list
[
list
[
int
]]
|
None
=
None
,
coord_store_port
:
int
,
enable_eplb
:
bool
=
True
,
backend
:
str
|
None
=
None
,
)
->
None
:
global
\
...
...
@@ -51,19 +49,23 @@ def create_standby_groups(
_STANDBY_EP
,
\
_STANDBY_EPLB
from
vllm.distributed.utils
import
get_cached_tcp_store_client
assert
new_world_size_across_dp
==
torch
.
distributed
.
get_world_size
()
*
new_dp_size
world_group
=
get_world_group
()
assert
isinstance
(
world_group
,
StatelessGroupCoordinator
)
backend
=
backend
or
world_group
.
backend
coord_store
=
get_cached_tcp_store_client
(
master_ip
,
coord_store_port
)
standby_world_ranks
=
[
list
(
range
(
new_world_size_across_dp
))]
_STANDBY_WORLD
=
_init_stateless_group
(
standby_world_ranks
,
"world"
,
world_group_ports
,
master_ip
,
backend
,
use_device_communicator
=
False
,
coord_store
=
coord_store
,
)
_STANDBY_WORLD_NODE_COUNT
=
_node_count
(
_STANDBY_WORLD
.
tcp_store_group
)
...
...
@@ -76,7 +78,7 @@ def create_standby_groups(
standby_dp_ranks
=
all_ranks
.
transpose
(
1
,
3
).
reshape
(
-
1
,
new_dp_size
).
unbind
(
0
)
standby_dp_ranks
=
[
x
.
tolist
()
for
x
in
standby_dp_ranks
]
_STANDBY_DP
=
_init_stateless_group
(
standby_dp_ranks
,
"dp"
,
dp_group_ports
,
master_ip
,
backend
standby_dp_ranks
,
"dp"
,
master_ip
,
backend
,
coord_store
=
coord_store
)
standby_ep_ranks
=
(
...
...
@@ -84,12 +86,16 @@ def create_standby_groups(
)
standby_ep_ranks
=
[
x
.
tolist
()
for
x
in
standby_ep_ranks
]
_STANDBY_EP
=
_init_stateless_group
(
standby_ep_ranks
,
"ep"
,
ep_group_ports
,
master_ip
,
backend
standby_ep_ranks
,
"ep"
,
master_ip
,
backend
,
coord_store
=
coord_store
)
if
e
plb_group_ports
is
not
None
:
if
e
nable_eplb
:
_STANDBY_EPLB
=
_init_stateless_group
(
standby_ep_ranks
,
"eplb"
,
eplb_group_ports
,
master_ip
,
backend
standby_ep_ranks
,
"eplb"
,
master_ip
,
backend
,
coord_store
=
coord_store
,
)
...
...
vllm/distributed/parallel_state.py
View file @
de1a86b7
...
...
@@ -40,13 +40,16 @@ import torch
import
torch.distributed
import
torch.distributed._functional_collectives
as
funcol
import
torch.distributed._symmetric_memory
from
torch.distributed
import
Backend
,
ProcessGroup
from
torch.distributed
import
Backend
,
ProcessGroup
,
Store
import
vllm.envs
as
envs
from
vllm.distributed.device_communicators.base_device_communicator
import
(
DeviceCommunicatorBase
,
)
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.distributed.utils
import
(
StatelessProcessGroup
,
get_cached_tcp_store_client
,
)
from
vllm.logger
import
init_logger
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.network_utils
import
get_distributed_init_method
...
...
@@ -1164,9 +1167,9 @@ def init_model_parallel_group(
def
_init_stateless_group
(
group_ranks
:
list
[
list
[
int
]],
group_name
:
str
,
group_ports
:
list
[
list
[
int
]],
host
:
str
,
backend
:
str
,
coord_store
:
Store
,
use_device_communicator
:
bool
=
True
,
)
->
"StatelessGroupCoordinator"
:
"""Create a StatelessGroupCoordinator with the given parameters."""
...
...
@@ -1180,7 +1183,7 @@ def _init_stateless_group(
use_device_communicator
=
use_device_communicator
,
group_name
=
group_name
,
host
=
host
,
group_ports
=
group_ports
,
coord_store
=
coord_store
,
global_rank
=
world
.
rank
,
global_world_size
=
world
.
world_size
,
)
...
...
@@ -1321,7 +1324,9 @@ def _init_elastic_ep_world(
group_ranks
=
[
all_ranks
[
i
:
i
+
1
]
for
i
in
range
(
global_world_size
)]
if
global_rank
in
all_ranks
:
group_ranks
=
[
all_ranks
]
group_ports
=
[
parallel_config
.
get_next_stateless_world_group_port
()]
coord_store
=
get_cached_tcp_store_client
(
parallel_config
.
data_parallel_master_ip
,
parallel_config
.
_coord_store_port
)
world
=
StatelessGroupCoordinator
(
group_ranks
=
group_ranks
,
local_rank
=
local_rank
,
...
...
@@ -1329,7 +1334,7 @@ def _init_elastic_ep_world(
use_device_communicator
=
False
,
group_name
=
"world"
,
host
=
parallel_config
.
data_parallel_master_ip
,
group_ports
=
group_ports
,
coord_store
=
coord_store
,
global_rank
=
global_rank
,
global_world_size
=
global_world_size
,
)
...
...
@@ -1513,7 +1518,13 @@ def initialize_model_parallel(
config
=
get_current_vllm_config
()
data_parallel_size
=
config
.
parallel_config
.
data_parallel_size
enable_elastic_ep
=
config
.
parallel_config
.
enable_elastic_ep
parallel_config
=
config
.
parallel_config
coord_store
:
Store
|
None
=
None
if
enable_elastic_ep
:
coord_store
=
get_cached_tcp_store_client
(
parallel_config
.
data_parallel_master_ip
,
parallel_config
.
_coord_store_port
,
)
# Use stateless world group for global information
world_size
=
get_world_group
().
world_size
rank
=
get_world_group
().
rank
...
...
@@ -1633,16 +1644,12 @@ def initialize_model_parallel(
group_ranks
=
all_ranks
.
transpose
(
1
,
4
).
reshape
(
-
1
,
data_parallel_size
).
unbind
(
0
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
if
enable_elastic_ep
:
parallel_config
=
config
.
parallel_config
dp_ports
=
[
parallel_config
.
get_next_stateless_dp_group_port
()
for
_
in
group_ranks
]
_DP
=
_init_stateless_group
(
group_ranks
,
"dp"
,
dp_ports
,
parallel_config
.
data_parallel_master_ip
,
backend
,
coord_store
=
coord_store
,
)
else
:
_DP
=
init_model_parallel_group
(
...
...
@@ -1665,16 +1672,12 @@ def initialize_model_parallel(
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
if
enable_elastic_ep
:
parallel_config
=
config
.
parallel_config
ep_ports
=
[
parallel_config
.
get_next_stateless_ep_group_port
()
for
_
in
group_ranks
]
_EP
=
_init_stateless_group
(
group_ranks
,
"ep"
,
ep_ports
,
parallel_config
.
data_parallel_master_ip
,
backend
,
coord_store
=
coord_store
,
)
else
:
_EP
=
init_model_parallel_group
(
...
...
@@ -1693,16 +1696,12 @@ def initialize_model_parallel(
and
config
.
parallel_config
.
enable_eplb
):
if
enable_elastic_ep
:
eplb_ports
=
[
parallel_config
.
get_next_stateless_eplb_group_port
()
for
_
in
group_ranks
]
_EPLB
=
_init_stateless_group
(
group_ranks
,
"eplb"
,
eplb_ports
,
parallel_config
.
data_parallel_master_ip
,
backend
,
coord_store
=
coord_store
,
)
else
:
_EPLB
=
init_model_parallel_group
(
...
...
vllm/distributed/stateless_coordinator.py
View file @
de1a86b7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
socket
import
struct
from
typing
import
Any
,
Optional
import
torch
from
torch.distributed
import
Backend
,
ProcessGroup
from
torch.distributed
import
Backend
,
ProcessGroup
,
Store
from
vllm.distributed.device_communicators.cuda_communicator
import
CudaCommunicator
from
vllm.distributed.parallel_state
import
(
...
...
@@ -23,6 +25,38 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger
=
init_logger
(
__name__
)
_PORTS_FMT
=
"!3I"
def
_allocate_group_ports
(
key
:
str
,
host
:
str
,
coord_store
:
Store
,
)
->
tuple
[
list
[
int
],
list
[
socket
.
socket
]]:
"""Bind 3 sockets and publish the ports to *coord_store*.
Called by rank 0 only. Returns ``(ports, sockets)`` with the
sockets still open.
"""
socks
:
list
[
socket
.
socket
]
=
[]
ports
:
list
[
int
]
=
[]
for
_
in
range
(
3
):
s
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
s
.
bind
((
host
,
0
))
s
.
listen
()
socks
.
append
(
s
)
ports
.
append
(
s
.
getsockname
()[
1
])
coord_store
.
set
(
key
,
struct
.
pack
(
_PORTS_FMT
,
*
ports
))
return
ports
,
socks
def
_fetch_group_ports
(
key
:
str
,
coord_store
:
Store
)
->
list
[
int
]:
"""Read 3 ports published by rank 0 from *coord_store*.
Blocks until the key is available.
"""
return
list
(
struct
.
unpack
(
_PORTS_FMT
,
coord_store
.
get
(
key
)))
class
StatelessGroupCoordinator
(
GroupCoordinator
):
"""
...
...
@@ -39,10 +73,10 @@ class StatelessGroupCoordinator(GroupCoordinator):
local_rank
:
int
,
torch_distributed_backend
:
str
|
Backend
,
use_device_communicator
:
bool
,
coord_store
:
Store
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
str
|
None
=
None
,
host
:
str
=
"127.0.0.1"
,
group_ports
:
list
[
list
[
int
]]
|
None
=
None
,
global_rank
:
int
=
0
,
global_world_size
:
int
=
1
,
):
...
...
@@ -61,17 +95,23 @@ class StatelessGroupCoordinator(GroupCoordinator):
backend
=
str
(
torch_distributed_backend
)
self
.
backend
=
backend
assert
group_ports
is
not
None
,
"group_ports is not provided"
for
idx
,
ranks
in
enumerate
(
group_ranks
):
if
self
.
rank
in
ranks
:
self
.
ranks
=
ranks
self
.
world_size
=
len
(
ranks
)
self
.
rank_in_group
=
ranks
.
index
(
self
.
rank
)
ports
=
group_ports
[
idx
]
device_port
=
ports
[
0
]
cpu_port
=
ports
[
1
]
tcp_store_port
=
ports
[
2
]
key
=
f
"
{
group_name
}
_
{
idx
}
"
if
self
.
rank_in_group
==
0
:
ports
,
socks
=
_allocate_group_ports
(
key
,
host
,
coord_store
,
)
else
:
ports
=
_fetch_group_ports
(
key
,
coord_store
)
socks
=
[]
device_port
,
cpu_port
,
tcp_store_port
=
ports
device_group
=
stateless_init_torch_distributed_process_group
(
host
=
host
,
...
...
@@ -80,6 +120,7 @@ class StatelessGroupCoordinator(GroupCoordinator):
world_size
=
self
.
world_size
,
backend
=
backend
,
group_name
=
f
"
{
self
.
unique_name
}
_device"
,
listen_socket
=
socks
[
0
]
if
socks
else
None
,
)
cpu_group
=
stateless_init_torch_distributed_process_group
(
host
=
host
,
...
...
@@ -88,12 +129,14 @@ class StatelessGroupCoordinator(GroupCoordinator):
world_size
=
self
.
world_size
,
backend
=
"gloo"
,
group_name
=
f
"
{
self
.
unique_name
}
_cpu"
,
listen_socket
=
socks
[
1
]
if
socks
else
None
,
)
tcp_store_group
=
StatelessProcessGroup
.
create
(
host
=
host
,
port
=
tcp_store_port
,
rank
=
self
.
rank_in_group
,
world_size
=
self
.
world_size
,
listen_socket
=
socks
[
2
]
if
socks
else
None
,
)
self_device_group
=
device_group
...
...
vllm/distributed/utils.py
View file @
de1a86b7
...
...
@@ -6,6 +6,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import
dataclasses
import
functools
import
os
import
pickle
import
socket
...
...
@@ -139,6 +140,29 @@ def get_pp_indices(
return
(
start_layer
,
end_layer
)
def
create_tcp_store
(
host
:
str
,
port
:
int
,
listen_socket
:
socket
.
socket
|
None
=
None
,
**
kwargs
:
Any
,
)
->
TCPStore
:
"""Create a TCPStore, optionally taking ownership of ``listen_socket``."""
if
listen_socket
is
None
:
return
TCPStore
(
host_name
=
host
,
port
=
port
,
**
kwargs
)
listen_fd
=
listen_socket
.
detach
()
try
:
return
TCPStore
(
host_name
=
host
,
port
=
port
,
master_listen_fd
=
listen_fd
,
**
kwargs
,
)
except
Exception
:
socket
.
close
(
listen_fd
)
raise
@
dataclasses
.
dataclass
class
StatelessProcessGroup
:
"""A dataclass to hold a metadata store, and the rank, world_size of the
...
...
@@ -150,9 +174,6 @@ class StatelessProcessGroup:
world_size
:
int
store
:
torch
.
_C
.
_distributed_c10d
.
Store
# stores a reference to the socket so that the file descriptor stays alive
socket
:
socket
.
socket
|
None
data_expiration_seconds
:
int
=
3600
# 1 hour
# dst rank -> counter
...
...
@@ -419,6 +440,7 @@ class StatelessProcessGroup:
world_size
:
int
,
data_expiration_seconds
:
int
=
3600
,
store_timeout
:
int
=
300
,
listen_socket
:
socket
.
socket
|
None
=
None
,
)
->
"StatelessProcessGroup"
:
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
...
...
@@ -436,36 +458,39 @@ class StatelessProcessGroup:
C, and D can call `StatelessProcessGroup.create` to form another group.
"""
# noqa
launch_server
=
rank
==
0
if
launch_server
:
# listen on the specified interface (instead of 0.0.0.0)
if
launch_server
and
listen_socket
is
None
:
listen_socket
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
listen_socket
.
setsockopt
(
socket
.
SOL_SOCKET
,
socket
.
SO_REUSEADDR
,
1
)
listen_socket
.
bind
((
host
,
port
))
listen_socket
.
listen
()
listen_fd
=
listen_socket
.
fileno
()
else
:
listen_socket
=
None
listen_fd
=
None
store
=
TCPStore
(
host_name
=
host
,
port
=
port
,
store
=
create_tcp_store
(
host
,
port
,
listen_socket
=
listen_socket
,
world_size
=
world_size
,
is_master
=
launch_server
,
timeout
=
timedelta
(
seconds
=
store_timeout
),
use_libuv
=
False
,
# for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd
=
listen_fd
,
)
return
StatelessProcessGroup
(
rank
=
rank
,
world_size
=
world_size
,
store
=
store
,
socket
=
listen_socket
,
data_expiration_seconds
=
data_expiration_seconds
,
)
@
functools
.
lru_cache
(
maxsize
=
1
)
def
get_cached_tcp_store_client
(
host
:
str
,
port
:
int
)
->
TCPStore
:
"""Return a cached TCPStore client.
Cached so that every call with the same ``(host, port)`` reuses the
same connection. A new ``(host, port)`` evicts the old entry.
"""
return
TCPStore
(
host
,
port
,
is_master
=
False
,
wait_for_workers
=
False
)
def
init_gloo_process_group
(
prefix_store
:
PrefixStore
,
group_rank
:
int
,
...
...
@@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group(
backend
:
str
,
group_name
:
str
|
None
=
None
,
return_store
:
bool
=
False
,
listen_socket
:
socket
.
socket
|
None
=
None
,
)
->
ProcessGroup
|
tuple
[
ProcessGroup
,
Store
]:
"""
A replacement for `torch.distributed.init_process_group` that does not
...
...
@@ -535,11 +561,27 @@ def stateless_init_torch_distributed_process_group(
are the same as process 1 and 5, the main communication channel is
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
When *listen_socket* is provided, the rendezvous step
is skipped and a ``TCPStore`` server is created directly using the
pre-bound socket. This is useful for eliminating TOCTOU races
between port allocation and binding.
"""
init_method
=
get_tcp_uri
(
host
,
port
)
backend
=
Backend
(
backend
)
# it is basically string
timeout
=
_get_default_timeout
(
backend
)
if
listen_socket
is
not
None
:
store
=
create_tcp_store
(
host
,
port
,
listen_socket
=
listen_socket
,
world_size
=
world_size
,
is_master
=
True
,
timeout
=
timeout
,
multi_tenant
=
True
,
)
else
:
store
,
rank
,
world_size
=
next
(
rendezvous
(
init_method
,
rank
,
world_size
,
timeout
=
timeout
)
)
...
...
vllm/v1/engine/__init__.py
View file @
de1a86b7
...
...
@@ -237,10 +237,7 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_master_ip
:
str
new_data_parallel_master_port
:
int
new_data_parallel_master_port_list
:
list
[
int
]
new_stateless_world_group_port_list
:
list
[
list
[
int
]]
new_stateless_dp_group_port_list
:
list
[
list
[
int
]]
new_stateless_ep_group_port_list
:
list
[
list
[
int
]]
new_stateless_eplb_group_port_list
:
list
[
list
[
int
]]
coord_store_port
:
int
class
ReconfigureRankType
(
enum
.
IntEnum
):
...
...
vllm/v1/engine/core.py
View file @
de1a86b7
...
...
@@ -1767,6 +1767,7 @@ class DPEngineCoreProc(EngineCoreProc):
new_parallel_config
.
_data_parallel_master_port_list
=
(
reconfig_request
.
new_data_parallel_master_port_list
)
new_parallel_config
.
_coord_store_port
=
reconfig_request
.
coord_store_port
is_scale_down
=
reconfig_request
.
new_data_parallel_size
<
old_dp_size
is_shutdown
=
(
...
...
vllm/v1/engine/core_client.py
View file @
de1a86b7
...
...
@@ -455,56 +455,6 @@ class ElasticScalingCache:
pending_notifications
:
dict
[
EEPNotificationType
,
set
[
int
]]
def
allocate_stateless_group_ports
(
parallel_config
,
new_data_parallel_size
:
int
):
"""
Allocate stateless group ports for elastic EP.
"""
from
vllm.utils.network_utils
import
get_open_ports_list
assert
parallel_config
.
enable_elastic_ep
,
"Elastic EP must be enabled"
world_size
=
parallel_config
.
world_size
new_world_size_across_dp
=
world_size
*
new_data_parallel_size
num_world_groups
=
1
num_dp_groups
=
max
(
1
,
new_world_size_across_dp
//
new_data_parallel_size
)
num_ep_groups
=
max
(
1
,
new_world_size_across_dp
//
(
new_data_parallel_size
*
parallel_config
.
tensor_parallel_size
),
)
num_eplb_groups
=
num_ep_groups
total_ports_needed
=
(
num_world_groups
+
num_dp_groups
+
num_ep_groups
+
num_eplb_groups
)
*
3
+
5
all_ports
=
get_open_ports_list
(
total_ports_needed
)
new_data_parallel_master_port_list
=
all_ports
[
-
5
:]
all_ports
=
all_ports
[:
-
5
]
new_stateless_world_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
0
,
num_world_groups
*
3
,
3
)
]
start_idx
=
num_world_groups
*
3
new_stateless_dp_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_dp_groups
*
3
,
3
)
]
start_idx
+=
num_dp_groups
*
3
new_stateless_ep_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_ep_groups
*
3
,
3
)
]
start_idx
+=
num_ep_groups
*
3
new_stateless_eplb_group_port_list
=
[
all_ports
[
i
:
i
+
3
]
for
i
in
range
(
start_idx
,
start_idx
+
num_eplb_groups
*
3
,
3
)
]
parallel_config
.
_stateless_world_group_port_list
=
(
new_stateless_world_group_port_list
)
parallel_config
.
_stateless_dp_group_port_list
=
new_stateless_dp_group_port_list
parallel_config
.
_stateless_ep_group_port_list
=
new_stateless_ep_group_port_list
parallel_config
.
_stateless_eplb_group_port_list
=
new_stateless_eplb_group_port_list
parallel_config
.
data_parallel_master_port
=
new_data_parallel_master_port_list
.
pop
()
parallel_config
.
_data_parallel_master_port_list
=
new_data_parallel_master_port_list
class
MPClient
(
EngineCoreClient
):
"""
MPClient: base client for multi-proc EngineCore.
...
...
@@ -1541,6 +1491,28 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
self
.
_ensure_output_queue_task
()
await
future
def
_setup_elastic_ep_reconfig_bootstrap
(
self
)
->
tuple
[
str
,
int
]:
from
vllm.distributed.utils
import
create_tcp_store
from
vllm.utils.network_utils
import
get_open_ports_list
parallel_config
=
self
.
vllm_config
.
parallel_config
parallel_config
.
_data_parallel_master_port_list
=
get_open_ports_list
(
5
)
parallel_config
.
data_parallel_master_port
=
(
parallel_config
.
_data_parallel_master_port_list
.
pop
()
)
ip
=
parallel_config
.
data_parallel_master_ip
store
=
create_tcp_store
(
ip
,
0
,
is_master
=
True
,
world_size
=-
1
,
wait_for_workers
=
False
,
)
parallel_config
.
_coord_store_port
=
store
.
port
self
.
_coord_store
=
store
return
ip
,
store
.
port
async
def
_scale_up_elastic_ep
(
self
,
cur_data_parallel_size
:
int
,
new_data_parallel_size
:
int
)
->
None
:
...
...
@@ -1555,7 +1527,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
parallel_config
=
self
.
vllm_config
.
parallel_config
allocate_stateless_group_ports
(
parallel_config
,
new_data_parallel_size
)
ip
,
coord_store_port
=
self
.
_setup_elastic_ep_reconfig_bootstrap
(
)
# Phase 1: Send reconfig messages to existing engines
reconfig_futures
=
[]
...
...
@@ -1564,13 +1536,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size
=
new_data_parallel_size
,
new_data_parallel_rank
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_rank_local
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_master_ip
=
parallel_config
.
data_parallel_master_
ip
,
new_data_parallel_master_ip
=
ip
,
new_data_parallel_master_port
=
parallel_config
.
data_parallel_master_port
,
new_data_parallel_master_port_list
=
parallel_config
.
_data_parallel_master_port_list
,
new_stateless_world_group_port_list
=
parallel_config
.
_stateless_world_group_port_list
,
new_stateless_dp_group_port_list
=
parallel_config
.
_stateless_dp_group_port_list
,
new_stateless_ep_group_port_list
=
parallel_config
.
_stateless_ep_group_port_list
,
new_stateless_eplb_group_port_list
=
parallel_config
.
_stateless_eplb_group_port_list
,
coord_store_port
=
coord_store_port
,
)
coro
=
self
.
_call_utility_async
(
"reinitialize_distributed"
,
reconfig_request
,
engine
=
engine
...
...
@@ -1650,7 +1619,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
parallel_config
=
self
.
vllm_config
.
parallel_config
allocate_stateless_group_ports
(
parallel_config
,
new_data_parallel_size
)
ip
,
coord_store_port
=
self
.
_setup_elastic_ep_reconfig_bootstrap
(
)
reconfig_futures
=
[]
for
cur_dp_rank
,
engine
in
enumerate
(
self
.
core_engines
):
...
...
@@ -1658,13 +1627,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size
=
new_data_parallel_size
,
new_data_parallel_rank
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_rank_local
=
ReconfigureRankType
.
KEEP_CURRENT_RANK
,
new_data_parallel_master_ip
=
parallel_config
.
data_parallel_master_
ip
,
new_data_parallel_master_ip
=
ip
,
new_data_parallel_master_port
=
parallel_config
.
data_parallel_master_port
,
new_data_parallel_master_port_list
=
parallel_config
.
_data_parallel_master_port_list
,
new_stateless_world_group_port_list
=
parallel_config
.
_stateless_world_group_port_list
,
new_stateless_dp_group_port_list
=
parallel_config
.
_stateless_dp_group_port_list
,
new_stateless_ep_group_port_list
=
parallel_config
.
_stateless_ep_group_port_list
,
new_stateless_eplb_group_port_list
=
parallel_config
.
_stateless_eplb_group_port_list
,
coord_store_port
=
coord_store_port
,
)
if
cur_dp_rank
>=
new_data_parallel_size
:
reconfig_request
.
new_data_parallel_rank
=
(
...
...
vllm/v1/engine/utils.py
View file @
de1a86b7
...
...
@@ -301,7 +301,20 @@ class CoreEngineActorManager:
else
:
ray
.
init
()
vllm_config
.
parallel_config
.
allocate_elastic_ep_ports
()
parallel_config
=
vllm_config
.
parallel_config
if
parallel_config
.
enable_elastic_ep
:
from
vllm.distributed.utils
import
create_tcp_store
ip
=
parallel_config
.
data_parallel_master_ip
store
=
create_tcp_store
(
ip
,
0
,
is_master
=
True
,
world_size
=-
1
,
wait_for_workers
=
False
,
)
parallel_config
.
_coord_store_port
=
store
.
port
self
.
_coord_store
=
store
if
placement_groups
is
not
None
:
assert
local_dp_ranks
is
not
None
,
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment