Unverified Commit de1a86b7 authored by Itay Alroy's avatar Itay Alroy Committed by GitHub
Browse files

elastic_ep: Fix stateless group port races (#36330)


Signed-off-by: default avatarItay Alroy <ialroy@nvidia.com>
parent 99267c23
...@@ -24,8 +24,7 @@ steps: ...@@ -24,8 +24,7 @@ steps:
- label: Elastic EP Scaling Test - label: Elastic EP Scaling Test
timeout_in_minutes: 20 timeout_in_minutes: 20
device: b200 device: h100
optional: true
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_devices: 4 num_devices: 4
source_file_dependencies: source_file_dependencies:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
import socket
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, overload from typing import TYPE_CHECKING, Any, Literal, overload
...@@ -266,33 +267,9 @@ class ParallelConfig: ...@@ -266,33 +267,9 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users. Set to be private as it's not intended to be configured by users.
""" """
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list) _coord_store_port: int = 0
"""List of open ports for stateless DP groups when enable_elastic_ep is True. """Port of the coordination TCPStore. Can be set by the API server; workers
Set to be private as it's not intended to be configured by users. connect as clients to exchange self-picked group ports at runtime."""
It is a list of list[int], with each inner list contains a set of 3 ports
to be used for setting up the stateless CPU/device/TCPStore groups
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
"""
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
"""
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
Same topology as EP but separate NCCL communicator to avoid deadlocks.
"""
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_world_group_port_list) == 1,
"""
decode_context_parallel_size: int = 1 decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does """Number of decode context parallel groups, because the world size does
...@@ -465,65 +442,32 @@ class ParallelConfig: ...@@ -465,65 +442,32 @@ class ParallelConfig:
return answer return answer
def allocate_elastic_ep_ports(self) -> None: def _pick_stateless_dp_port(self) -> tuple[int, socket.socket | None]:
"""Allocate all ports for elastic EP (stateless groups + DP master). """Return ``(port, listen_socket)`` for DP group init.
Must be called AFTER ray.init() so that ports claimed by Ray's With a coord store, rank 0 binds a socket and publishes the port;
idle worker pool are already in use and won't be returned by others read it. Without one, pops a pre-allocated port and
get_open_ports_list(). returns ``listen_socket=None``.
""" """
if not self.enable_elastic_ep: if not self._coord_store_port:
return return self.get_next_dp_init_port(), None
if self._stateless_world_group_port_list:
return from vllm.distributed.utils import get_cached_tcp_store_client
num_world_groups = 1 store = get_cached_tcp_store_client(
dp_size = self.data_parallel_size self.data_parallel_master_ip, self._coord_store_port
ep_size = self.data_parallel_size * self.world_size_across_dp )
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
num_ep_groups = max(1, self.world_size_across_dp // ep_size) key = "dp_master_port"
num_eplb_groups = num_ep_groups if self.data_parallel_rank == 0:
total_stateless_ports = ( s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups s.bind((self.data_parallel_master_ip, 0))
) * 3 s.listen()
num_dp_master_ports = 5 port = s.getsockname()[1]
store.set(key, str(port).encode())
all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports) return port, s
else:
self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:] return int(store.get(key).decode()), None
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
all_ports = all_ports[:-num_dp_master_ports]
self._stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
self._stateless_dp_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
self._stateless_ep_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
self._stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
def get_next_stateless_world_group_port(self) -> list[int]:
return self._stateless_world_group_port_list.pop()
def get_next_stateless_dp_group_port(self) -> list[int]:
return self._stateless_dp_group_port_list.pop()
def get_next_stateless_ep_group_port(self) -> list[int]:
return self._stateless_ep_group_port_list.pop()
def get_next_stateless_eplb_group_port(self) -> list[int]:
return self._stateless_eplb_group_port_list.pop()
@overload @overload
def stateless_init_dp_group( def stateless_init_dp_group(
...@@ -553,14 +497,16 @@ class ParallelConfig: ...@@ -553,14 +497,16 @@ class ParallelConfig:
last_exc: Exception | None = None last_exc: Exception | None = None
for _ in range(max_retries): for _ in range(max_retries):
try: try:
port, listen_socket = self._pick_stateless_dp_port()
# use gloo since the engine process might not have cuda device # use gloo since the engine process might not have cuda device
return stateless_init_torch_distributed_process_group( return stateless_init_torch_distributed_process_group(
self.data_parallel_master_ip, self.data_parallel_master_ip,
self.get_next_dp_init_port(), port,
self.data_parallel_rank, self.data_parallel_rank,
self.data_parallel_size, self.data_parallel_size,
backend="gloo", backend="gloo",
return_store=return_store, return_store=return_store,
listen_socket=listen_socket,
) )
except DistNetworkError as e: except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE. # We only want to retry when the root cause is EADDRINUSE.
......
...@@ -162,10 +162,8 @@ class ElasticEPScalingExecutor: ...@@ -162,10 +162,8 @@ class ElasticEPScalingExecutor:
new_dp_size=new_dp_size, new_dp_size=new_dp_size,
new_world_size_across_dp=new_world_size_across_dp, new_world_size_across_dp=new_world_size_across_dp,
master_ip=reconfig_request.new_data_parallel_master_ip, master_ip=reconfig_request.new_data_parallel_master_ip,
world_group_ports=reconfig_request.new_stateless_world_group_port_list, coord_store_port=reconfig_request.coord_store_port,
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list, enable_eplb=updated_config.parallel_config.enable_eplb,
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
) )
self.worker.model_runner.eep_eplb_suppressed = True self.worker.model_runner.eep_eplb_suppressed = True
standby_ep_group = get_standby_ep_group() standby_ep_group = get_standby_ep_group()
......
...@@ -563,15 +563,4 @@ class ElasticEPScalingState: ...@@ -563,15 +563,4 @@ class ElasticEPScalingState:
parallel_config._data_parallel_master_port_list = ( parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list reconfig_request.new_data_parallel_master_port_list
) )
parallel_config._stateless_world_group_port_list = ( parallel_config._coord_store_port = reconfig_request.coord_store_port
reconfig_request.new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = (
reconfig_request.new_stateless_dp_group_port_list
)
parallel_config._stateless_ep_group_port_list = (
reconfig_request.new_stateless_ep_group_port_list
)
parallel_config._stateless_eplb_group_port_list = (
reconfig_request.new_stateless_eplb_group_port_list
)
...@@ -38,10 +38,8 @@ def create_standby_groups( ...@@ -38,10 +38,8 @@ def create_standby_groups(
new_dp_size: int, new_dp_size: int,
new_world_size_across_dp: int, new_world_size_across_dp: int,
master_ip: str, master_ip: str,
world_group_ports: list[list[int]], coord_store_port: int,
dp_group_ports: list[list[int]], enable_eplb: bool = True,
ep_group_ports: list[list[int]],
eplb_group_ports: list[list[int]] | None = None,
backend: str | None = None, backend: str | None = None,
) -> None: ) -> None:
global \ global \
...@@ -51,19 +49,23 @@ def create_standby_groups( ...@@ -51,19 +49,23 @@ def create_standby_groups(
_STANDBY_EP, \ _STANDBY_EP, \
_STANDBY_EPLB _STANDBY_EPLB
from vllm.distributed.utils import get_cached_tcp_store_client
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
world_group = get_world_group() world_group = get_world_group()
assert isinstance(world_group, StatelessGroupCoordinator) assert isinstance(world_group, StatelessGroupCoordinator)
backend = backend or world_group.backend backend = backend or world_group.backend
coord_store = get_cached_tcp_store_client(master_ip, coord_store_port)
standby_world_ranks = [list(range(new_world_size_across_dp))] standby_world_ranks = [list(range(new_world_size_across_dp))]
_STANDBY_WORLD = _init_stateless_group( _STANDBY_WORLD = _init_stateless_group(
standby_world_ranks, standby_world_ranks,
"world", "world",
world_group_ports,
master_ip, master_ip,
backend, backend,
use_device_communicator=False, use_device_communicator=False,
coord_store=coord_store,
) )
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group) _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
...@@ -76,7 +78,7 @@ def create_standby_groups( ...@@ -76,7 +78,7 @@ def create_standby_groups(
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0) standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks] standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
_STANDBY_DP = _init_stateless_group( _STANDBY_DP = _init_stateless_group(
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend standby_dp_ranks, "dp", master_ip, backend, coord_store=coord_store
) )
standby_ep_ranks = ( standby_ep_ranks = (
...@@ -84,12 +86,16 @@ def create_standby_groups( ...@@ -84,12 +86,16 @@ def create_standby_groups(
) )
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks] standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
_STANDBY_EP = _init_stateless_group( _STANDBY_EP = _init_stateless_group(
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend standby_ep_ranks, "ep", master_ip, backend, coord_store=coord_store
) )
if eplb_group_ports is not None: if enable_eplb:
_STANDBY_EPLB = _init_stateless_group( _STANDBY_EPLB = _init_stateless_group(
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend standby_ep_ranks,
"eplb",
master_ip,
backend,
coord_store=coord_store,
) )
......
...@@ -40,13 +40,16 @@ import torch ...@@ -40,13 +40,16 @@ import torch
import torch.distributed import torch.distributed
import torch.distributed._functional_collectives as funcol import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory import torch.distributed._symmetric_memory
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup, Store
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import ( from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase, DeviceCommunicatorBase,
) )
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import (
StatelessProcessGroup,
get_cached_tcp_store_client,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.network_utils import get_distributed_init_method from vllm.utils.network_utils import get_distributed_init_method
...@@ -1164,9 +1167,9 @@ def init_model_parallel_group( ...@@ -1164,9 +1167,9 @@ def init_model_parallel_group(
def _init_stateless_group( def _init_stateless_group(
group_ranks: list[list[int]], group_ranks: list[list[int]],
group_name: str, group_name: str,
group_ports: list[list[int]],
host: str, host: str,
backend: str, backend: str,
coord_store: Store,
use_device_communicator: bool = True, use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator": ) -> "StatelessGroupCoordinator":
"""Create a StatelessGroupCoordinator with the given parameters.""" """Create a StatelessGroupCoordinator with the given parameters."""
...@@ -1180,7 +1183,7 @@ def _init_stateless_group( ...@@ -1180,7 +1183,7 @@ def _init_stateless_group(
use_device_communicator=use_device_communicator, use_device_communicator=use_device_communicator,
group_name=group_name, group_name=group_name,
host=host, host=host,
group_ports=group_ports, coord_store=coord_store,
global_rank=world.rank, global_rank=world.rank,
global_world_size=world.world_size, global_world_size=world.world_size,
) )
...@@ -1321,7 +1324,9 @@ def _init_elastic_ep_world( ...@@ -1321,7 +1324,9 @@ def _init_elastic_ep_world(
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)] group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
if global_rank in all_ranks: if global_rank in all_ranks:
group_ranks = [all_ranks] group_ranks = [all_ranks]
group_ports = [parallel_config.get_next_stateless_world_group_port()] coord_store = get_cached_tcp_store_client(
parallel_config.data_parallel_master_ip, parallel_config._coord_store_port
)
world = StatelessGroupCoordinator( world = StatelessGroupCoordinator(
group_ranks=group_ranks, group_ranks=group_ranks,
local_rank=local_rank, local_rank=local_rank,
...@@ -1329,7 +1334,7 @@ def _init_elastic_ep_world( ...@@ -1329,7 +1334,7 @@ def _init_elastic_ep_world(
use_device_communicator=False, use_device_communicator=False,
group_name="world", group_name="world",
host=parallel_config.data_parallel_master_ip, host=parallel_config.data_parallel_master_ip,
group_ports=group_ports, coord_store=coord_store,
global_rank=global_rank, global_rank=global_rank,
global_world_size=global_world_size, global_world_size=global_world_size,
) )
...@@ -1513,7 +1518,13 @@ def initialize_model_parallel( ...@@ -1513,7 +1518,13 @@ def initialize_model_parallel(
config = get_current_vllm_config() config = get_current_vllm_config()
data_parallel_size = config.parallel_config.data_parallel_size data_parallel_size = config.parallel_config.data_parallel_size
enable_elastic_ep = config.parallel_config.enable_elastic_ep enable_elastic_ep = config.parallel_config.enable_elastic_ep
parallel_config = config.parallel_config
coord_store: Store | None = None
if enable_elastic_ep: if enable_elastic_ep:
coord_store = get_cached_tcp_store_client(
parallel_config.data_parallel_master_ip,
parallel_config._coord_store_port,
)
# Use stateless world group for global information # Use stateless world group for global information
world_size = get_world_group().world_size world_size = get_world_group().world_size
rank = get_world_group().rank rank = get_world_group().rank
...@@ -1633,16 +1644,12 @@ def initialize_model_parallel( ...@@ -1633,16 +1644,12 @@ def initialize_model_parallel(
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep: if enable_elastic_ep:
parallel_config = config.parallel_config
dp_ports = [
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
]
_DP = _init_stateless_group( _DP = _init_stateless_group(
group_ranks, group_ranks,
"dp", "dp",
dp_ports,
parallel_config.data_parallel_master_ip, parallel_config.data_parallel_master_ip,
backend, backend,
coord_store=coord_store,
) )
else: else:
_DP = init_model_parallel_group( _DP = init_model_parallel_group(
...@@ -1665,16 +1672,12 @@ def initialize_model_parallel( ...@@ -1665,16 +1672,12 @@ def initialize_model_parallel(
) )
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep: if enable_elastic_ep:
parallel_config = config.parallel_config
ep_ports = [
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
]
_EP = _init_stateless_group( _EP = _init_stateless_group(
group_ranks, group_ranks,
"ep", "ep",
ep_ports,
parallel_config.data_parallel_master_ip, parallel_config.data_parallel_master_ip,
backend, backend,
coord_store=coord_store,
) )
else: else:
_EP = init_model_parallel_group( _EP = init_model_parallel_group(
...@@ -1693,16 +1696,12 @@ def initialize_model_parallel( ...@@ -1693,16 +1696,12 @@ def initialize_model_parallel(
and config.parallel_config.enable_eplb and config.parallel_config.enable_eplb
): ):
if enable_elastic_ep: if enable_elastic_ep:
eplb_ports = [
parallel_config.get_next_stateless_eplb_group_port()
for _ in group_ranks
]
_EPLB = _init_stateless_group( _EPLB = _init_stateless_group(
group_ranks, group_ranks,
"eplb", "eplb",
eplb_ports,
parallel_config.data_parallel_master_ip, parallel_config.data_parallel_master_ip,
backend, backend,
coord_store=coord_store,
) )
else: else:
_EPLB = init_model_parallel_group( _EPLB = init_model_parallel_group(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import socket
import struct
from typing import Any, Optional from typing import Any, Optional
import torch import torch
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup, Store
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -23,6 +25,38 @@ from vllm.utils.import_utils import resolve_obj_by_qualname ...@@ -23,6 +25,38 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
_PORTS_FMT = "!3I"
def _allocate_group_ports(
key: str,
host: str,
coord_store: Store,
) -> tuple[list[int], list[socket.socket]]:
"""Bind 3 sockets and publish the ports to *coord_store*.
Called by rank 0 only. Returns ``(ports, sockets)`` with the
sockets still open.
"""
socks: list[socket.socket] = []
ports: list[int] = []
for _ in range(3):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((host, 0))
s.listen()
socks.append(s)
ports.append(s.getsockname()[1])
coord_store.set(key, struct.pack(_PORTS_FMT, *ports))
return ports, socks
def _fetch_group_ports(key: str, coord_store: Store) -> list[int]:
"""Read 3 ports published by rank 0 from *coord_store*.
Blocks until the key is available.
"""
return list(struct.unpack(_PORTS_FMT, coord_store.get(key)))
class StatelessGroupCoordinator(GroupCoordinator): class StatelessGroupCoordinator(GroupCoordinator):
""" """
...@@ -39,10 +73,10 @@ class StatelessGroupCoordinator(GroupCoordinator): ...@@ -39,10 +73,10 @@ class StatelessGroupCoordinator(GroupCoordinator):
local_rank: int, local_rank: int,
torch_distributed_backend: str | Backend, torch_distributed_backend: str | Backend,
use_device_communicator: bool, use_device_communicator: bool,
coord_store: Store,
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: str | None = None, group_name: str | None = None,
host: str = "127.0.0.1", host: str = "127.0.0.1",
group_ports: list[list[int]] | None = None,
global_rank: int = 0, global_rank: int = 0,
global_world_size: int = 1, global_world_size: int = 1,
): ):
...@@ -61,17 +95,23 @@ class StatelessGroupCoordinator(GroupCoordinator): ...@@ -61,17 +95,23 @@ class StatelessGroupCoordinator(GroupCoordinator):
backend = str(torch_distributed_backend) backend = str(torch_distributed_backend)
self.backend = backend self.backend = backend
assert group_ports is not None, "group_ports is not provided"
for idx, ranks in enumerate(group_ranks): for idx, ranks in enumerate(group_ranks):
if self.rank in ranks: if self.rank in ranks:
self.ranks = ranks self.ranks = ranks
self.world_size = len(ranks) self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank) self.rank_in_group = ranks.index(self.rank)
ports = group_ports[idx] key = f"{group_name}_{idx}"
device_port = ports[0] if self.rank_in_group == 0:
cpu_port = ports[1] ports, socks = _allocate_group_ports(
tcp_store_port = ports[2] key,
host,
coord_store,
)
else:
ports = _fetch_group_ports(key, coord_store)
socks = []
device_port, cpu_port, tcp_store_port = ports
device_group = stateless_init_torch_distributed_process_group( device_group = stateless_init_torch_distributed_process_group(
host=host, host=host,
...@@ -80,6 +120,7 @@ class StatelessGroupCoordinator(GroupCoordinator): ...@@ -80,6 +120,7 @@ class StatelessGroupCoordinator(GroupCoordinator):
world_size=self.world_size, world_size=self.world_size,
backend=backend, backend=backend,
group_name=f"{self.unique_name}_device", group_name=f"{self.unique_name}_device",
listen_socket=socks[0] if socks else None,
) )
cpu_group = stateless_init_torch_distributed_process_group( cpu_group = stateless_init_torch_distributed_process_group(
host=host, host=host,
...@@ -88,12 +129,14 @@ class StatelessGroupCoordinator(GroupCoordinator): ...@@ -88,12 +129,14 @@ class StatelessGroupCoordinator(GroupCoordinator):
world_size=self.world_size, world_size=self.world_size,
backend="gloo", backend="gloo",
group_name=f"{self.unique_name}_cpu", group_name=f"{self.unique_name}_cpu",
listen_socket=socks[1] if socks else None,
) )
tcp_store_group = StatelessProcessGroup.create( tcp_store_group = StatelessProcessGroup.create(
host=host, host=host,
port=tcp_store_port, port=tcp_store_port,
rank=self.rank_in_group, rank=self.rank_in_group,
world_size=self.world_size, world_size=self.world_size,
listen_socket=socks[2] if socks else None,
) )
self_device_group = device_group self_device_group = device_group
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses import dataclasses
import functools
import os import os
import pickle import pickle
import socket import socket
...@@ -139,6 +140,29 @@ def get_pp_indices( ...@@ -139,6 +140,29 @@ def get_pp_indices(
return (start_layer, end_layer) return (start_layer, end_layer)
def create_tcp_store(
host: str,
port: int,
listen_socket: socket.socket | None = None,
**kwargs: Any,
) -> TCPStore:
"""Create a TCPStore, optionally taking ownership of ``listen_socket``."""
if listen_socket is None:
return TCPStore(host_name=host, port=port, **kwargs)
listen_fd = listen_socket.detach()
try:
return TCPStore(
host_name=host,
port=port,
master_listen_fd=listen_fd,
**kwargs,
)
except Exception:
socket.close(listen_fd)
raise
@dataclasses.dataclass @dataclasses.dataclass
class StatelessProcessGroup: class StatelessProcessGroup:
"""A dataclass to hold a metadata store, and the rank, world_size of the """A dataclass to hold a metadata store, and the rank, world_size of the
...@@ -150,9 +174,6 @@ class StatelessProcessGroup: ...@@ -150,9 +174,6 @@ class StatelessProcessGroup:
world_size: int world_size: int
store: torch._C._distributed_c10d.Store store: torch._C._distributed_c10d.Store
# stores a reference to the socket so that the file descriptor stays alive
socket: socket.socket | None
data_expiration_seconds: int = 3600 # 1 hour data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter # dst rank -> counter
...@@ -419,6 +440,7 @@ class StatelessProcessGroup: ...@@ -419,6 +440,7 @@ class StatelessProcessGroup:
world_size: int, world_size: int,
data_expiration_seconds: int = 3600, data_expiration_seconds: int = 3600,
store_timeout: int = 300, store_timeout: int = 300,
listen_socket: socket.socket | None = None,
) -> "StatelessProcessGroup": ) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not """A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. pollute the global state.
...@@ -436,36 +458,39 @@ class StatelessProcessGroup: ...@@ -436,36 +458,39 @@ class StatelessProcessGroup:
C, and D can call `StatelessProcessGroup.create` to form another group. C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa """ # noqa
launch_server = rank == 0 launch_server = rank == 0
if launch_server: if launch_server and listen_socket is None:
# listen on the specified interface (instead of 0.0.0.0)
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listen_socket.bind((host, port)) listen_socket.bind((host, port))
listen_socket.listen() listen_socket.listen()
listen_fd = listen_socket.fileno() store = create_tcp_store(
else: host,
listen_socket = None port,
listen_fd = None listen_socket=listen_socket,
store = TCPStore(
host_name=host,
port=port,
world_size=world_size, world_size=world_size,
is_master=launch_server, is_master=launch_server,
timeout=timedelta(seconds=store_timeout), timeout=timedelta(seconds=store_timeout),
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd=listen_fd,
) )
return StatelessProcessGroup( return StatelessProcessGroup(
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
store=store, store=store,
socket=listen_socket,
data_expiration_seconds=data_expiration_seconds, data_expiration_seconds=data_expiration_seconds,
) )
@functools.lru_cache(maxsize=1)
def get_cached_tcp_store_client(host: str, port: int) -> TCPStore:
"""Return a cached TCPStore client.
Cached so that every call with the same ``(host, port)`` reuses the
same connection. A new ``(host, port)`` evicts the old entry.
"""
return TCPStore(host, port, is_master=False, wait_for_workers=False)
def init_gloo_process_group( def init_gloo_process_group(
prefix_store: PrefixStore, prefix_store: PrefixStore,
group_rank: int, group_rank: int,
...@@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group( ...@@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group(
backend: str, backend: str,
group_name: str | None = None, group_name: str | None = None,
return_store: bool = False, return_store: bool = False,
listen_socket: socket.socket | None = None,
) -> ProcessGroup | tuple[ProcessGroup, Store]: ) -> ProcessGroup | tuple[ProcessGroup, Store]:
""" """
A replacement for `torch.distributed.init_process_group` that does not A replacement for `torch.distributed.init_process_group` that does not
...@@ -535,14 +561,30 @@ def stateless_init_torch_distributed_process_group( ...@@ -535,14 +561,30 @@ def stateless_init_torch_distributed_process_group(
are the same as process 1 and 5, the main communication channel is are the same as process 1 and 5, the main communication channel is
always formed with process 1, 2, ..., 8, and the additional communication always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10. channel is formed with process 9 and 10.
When *listen_socket* is provided, the rendezvous step
is skipped and a ``TCPStore`` server is created directly using the
pre-bound socket. This is useful for eliminating TOCTOU races
between port allocation and binding.
""" """
init_method = get_tcp_uri(host, port) init_method = get_tcp_uri(host, port)
backend = Backend(backend) # it is basically string backend = Backend(backend) # it is basically string
timeout = _get_default_timeout(backend) timeout = _get_default_timeout(backend)
store, rank, world_size = next( if listen_socket is not None:
rendezvous(init_method, rank, world_size, timeout=timeout) store = create_tcp_store(
) host,
port,
listen_socket=listen_socket,
world_size=world_size,
is_master=True,
timeout=timeout,
multi_tenant=True,
)
else:
store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout)
)
store.set_timeout(timeout) store.set_timeout(timeout)
group_rank = rank group_rank = rank
......
...@@ -237,10 +237,7 @@ class ReconfigureDistributedRequest(msgspec.Struct): ...@@ -237,10 +237,7 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_master_ip: str new_data_parallel_master_ip: str
new_data_parallel_master_port: int new_data_parallel_master_port: int
new_data_parallel_master_port_list: list[int] new_data_parallel_master_port_list: list[int]
new_stateless_world_group_port_list: list[list[int]] coord_store_port: int
new_stateless_dp_group_port_list: list[list[int]]
new_stateless_ep_group_port_list: list[list[int]]
new_stateless_eplb_group_port_list: list[list[int]]
class ReconfigureRankType(enum.IntEnum): class ReconfigureRankType(enum.IntEnum):
......
...@@ -1767,6 +1767,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1767,6 +1767,7 @@ class DPEngineCoreProc(EngineCoreProc):
new_parallel_config._data_parallel_master_port_list = ( new_parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list reconfig_request.new_data_parallel_master_port_list
) )
new_parallel_config._coord_store_port = reconfig_request.coord_store_port
is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
is_shutdown = ( is_shutdown = (
......
...@@ -455,56 +455,6 @@ class ElasticScalingCache: ...@@ -455,56 +455,6 @@ class ElasticScalingCache:
pending_notifications: dict[EEPNotificationType, set[int]] pending_notifications: dict[EEPNotificationType, set[int]]
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
"""
Allocate stateless group ports for elastic EP.
"""
from vllm.utils.network_utils import get_open_ports_list
assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"
world_size = parallel_config.world_size
new_world_size_across_dp = world_size * new_data_parallel_size
num_world_groups = 1
num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
num_ep_groups = max(
1,
new_world_size_across_dp
// (new_data_parallel_size * parallel_config.tensor_parallel_size),
)
num_eplb_groups = num_ep_groups
total_ports_needed = (
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3 + 5
all_ports = get_open_ports_list(total_ports_needed)
new_data_parallel_master_port_list = all_ports[-5:]
all_ports = all_ports[:-5]
new_stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
new_stateless_dp_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
new_stateless_ep_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
new_stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
parallel_config._stateless_world_group_port_list = (
new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list
parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list
parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list
parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop()
parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list
class MPClient(EngineCoreClient): class MPClient(EngineCoreClient):
""" """
MPClient: base client for multi-proc EngineCore. MPClient: base client for multi-proc EngineCore.
...@@ -1541,6 +1491,28 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1541,6 +1491,28 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
self._ensure_output_queue_task() self._ensure_output_queue_task()
await future await future
def _setup_elastic_ep_reconfig_bootstrap(self) -> tuple[str, int]:
from vllm.distributed.utils import create_tcp_store
from vllm.utils.network_utils import get_open_ports_list
parallel_config = self.vllm_config.parallel_config
parallel_config._data_parallel_master_port_list = get_open_ports_list(5)
parallel_config.data_parallel_master_port = (
parallel_config._data_parallel_master_port_list.pop()
)
ip = parallel_config.data_parallel_master_ip
store = create_tcp_store(
ip,
0,
is_master=True,
world_size=-1,
wait_for_workers=False,
)
parallel_config._coord_store_port = store.port
self._coord_store = store
return ip, store.port
async def _scale_up_elastic_ep( async def _scale_up_elastic_ep(
self, cur_data_parallel_size: int, new_data_parallel_size: int self, cur_data_parallel_size: int, new_data_parallel_size: int
) -> None: ) -> None:
...@@ -1555,7 +1527,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1555,7 +1527,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
) )
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size) ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
# Phase 1: Send reconfig messages to existing engines # Phase 1: Send reconfig messages to existing engines
reconfig_futures = [] reconfig_futures = []
...@@ -1564,13 +1536,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1564,13 +1536,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size=new_data_parallel_size, new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, new_data_parallel_master_ip=ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port, new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, coord_store_port=coord_store_port,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
) )
coro = self._call_utility_async( coro = self._call_utility_async(
"reinitialize_distributed", reconfig_request, engine=engine "reinitialize_distributed", reconfig_request, engine=engine
...@@ -1650,7 +1619,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1650,7 +1619,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
) )
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size) ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
reconfig_futures = [] reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines): for cur_dp_rank, engine in enumerate(self.core_engines):
...@@ -1658,13 +1627,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1658,13 +1627,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size=new_data_parallel_size, new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, new_data_parallel_master_ip=ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port, new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, coord_store_port=coord_store_port,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
) )
if cur_dp_rank >= new_data_parallel_size: if cur_dp_rank >= new_data_parallel_size:
reconfig_request.new_data_parallel_rank = ( reconfig_request.new_data_parallel_rank = (
......
...@@ -301,7 +301,20 @@ class CoreEngineActorManager: ...@@ -301,7 +301,20 @@ class CoreEngineActorManager:
else: else:
ray.init() ray.init()
vllm_config.parallel_config.allocate_elastic_ep_ports() parallel_config = vllm_config.parallel_config
if parallel_config.enable_elastic_ep:
from vllm.distributed.utils import create_tcp_store
ip = parallel_config.data_parallel_master_ip
store = create_tcp_store(
ip,
0,
is_master=True,
world_size=-1,
wait_for_workers=False,
)
parallel_config._coord_store_port = store.port
self._coord_store = store
if placement_groups is not None: if placement_groups is not None:
assert local_dp_ranks is not None, ( assert local_dp_ranks is not None, (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment