Unverified commit 67132945, authored by Ilya Markov, committed by GitHub
Browse files

[Perf] Move eplb rebalance algo to async thread (#30888)


Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
parent f0ca0671
...@@ -295,12 +295,11 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -295,12 +295,11 @@ def _test_async_transfer_layer_without_mtp_worker(
for layer_idx in range(num_layers): for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run( is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer( transfer_layer(
old_global_expert_indices=old_indices_cpu, old_layer_indices=old_indices_cpu[layer_idx],
new_global_expert_indices=new_indices_cpu, new_layer_indices=new_indices_cpu[layer_idx],
expert_weights=expert_weights, expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer, expert_weights_buffer=expert_buffer,
ep_group=ep_group, ep_group=ep_group,
layer=layer_idx,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
) )
) )
......
...@@ -11,13 +11,13 @@ from typing import TYPE_CHECKING ...@@ -11,13 +11,13 @@ from typing import TYPE_CHECKING
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.distributed.parallel_state import get_ep_group from vllm.distributed.parallel_state import get_eplb_group
from vllm.logger import init_logger from vllm.logger import init_logger
from .rebalance_execute import transfer_layer from .rebalance_execute import transfer_layer
if TYPE_CHECKING: if TYPE_CHECKING:
from .eplb_state import EplbState from .eplb_state import EplbModelState, EplbState
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -27,8 +27,8 @@ def start_async_worker( ...@@ -27,8 +27,8 @@ def start_async_worker(
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
is_profile: bool = False, is_profile: bool = False,
) -> threading.Thread: ) -> threading.Thread:
ep_group = get_ep_group().device_group eplb_group = get_eplb_group().device_group
rank = ep_group.rank() rank = eplb_group.rank()
device_index = state.cuda_device_index device_index = state.cuda_device_index
assert state.is_async assert state.is_async
...@@ -42,7 +42,7 @@ def start_async_worker( ...@@ -42,7 +42,7 @@ def start_async_worker(
loop.run_until_complete( loop.run_until_complete(
transfer_run_periodically( transfer_run_periodically(
state=state, state=state,
ep_group=ep_group, eplb_group=eplb_group,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
is_profile=is_profile, is_profile=is_profile,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
...@@ -58,9 +58,53 @@ def start_async_worker( ...@@ -58,9 +58,53 @@ def start_async_worker(
return thread return thread
def run_rebalance_experts(
    model_state: "EplbModelState",
    eplb_state: "EplbState",
    physical_to_logical_map_cpu: torch.Tensor,
) -> None:
    """Run the rebalancing policy and stage the new mappings on `model_state`.

    The function reads `model_state.eplb_stats`, waits on
    `model_state.window_ready_event`, calls
    `eplb_state.policy.rebalance_experts(...)` and stores the three results in
    `model_state.new_physical_to_logical_map`,
    `model_state.new_logical_to_physical_map` and
    `model_state.new_logical_replica_count`.

    Args:
        model_state: Per-model state. `eplb_stats` and `window_ready_event`
            must both be set; `window_ready_event` is reset to `None` here.
        eplb_state: Global state; only its `policy` attribute is used.
        physical_to_logical_map_cpu: Current physical-to-logical map, passed
            to the policy as its last argument.

    Raises:
        AssertionError: If `eplb_stats` or `window_ready_event` is `None`, or
            if the policy returns a physical-to-logical map that is not on CPU.
    """
    assert model_state.eplb_stats is not None
    eplb_stats = model_state.eplb_stats

    # Wait for the main thread's all-reduce and clone to complete before
    # accessing the global_expert_load_window tensor.
    assert model_state.window_ready_event is not None
    model_state.window_ready_event.wait()
    # The event is single-use: drop it once it has been waited on.
    model_state.window_ready_event = None

    # Move the global expert load window to CPU for computation.
    global_expert_load_window = eplb_stats.global_expert_load_window.cpu()

    # Compute new expert mappings for the model.
    (
        new_physical_to_logical_map,
        new_logical_to_physical_map,
        new_logical_replica_count,
    ) = eplb_state.policy.rebalance_experts(
        global_expert_load_window,
        eplb_stats.num_replicas,
        eplb_stats.num_groups,
        eplb_stats.num_nodes,
        eplb_stats.num_gpus,
        physical_to_logical_map_cpu,
    )
    assert new_physical_to_logical_map.device == torch.device("cpu")
    model_state.new_physical_to_logical_map = new_physical_to_logical_map

    # Pad the last dimension with -1 up to the width of the map currently held
    # by the model (never truncates: the pad amount is clamped at 0), then move
    # the result to the device of the current map.
    max_slots = model_state.logical_to_physical_map.shape[-1]
    padded_logical = torch.nn.functional.pad(
        new_logical_to_physical_map,
        (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
        value=-1,
    ).to(model_state.logical_to_physical_map.device)
    new_replica = new_logical_replica_count.to(
        model_state.logical_replica_count.device
    )
    model_state.new_logical_to_physical_map = padded_logical
    model_state.new_logical_replica_count = new_replica
async def transfer_run_periodically( async def transfer_run_periodically(
state: "EplbState", state: "EplbState",
ep_group: ProcessGroup, eplb_group: ProcessGroup,
cuda_stream: torch.cuda.Stream, cuda_stream: torch.cuda.Stream,
is_profile: bool = False, is_profile: bool = False,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
...@@ -71,23 +115,51 @@ async def transfer_run_periodically( ...@@ -71,23 +115,51 @@ async def transfer_run_periodically(
assert state.is_async assert state.is_async
for model_state in state.model_states.values(): for model_state in state.model_states.values():
rebalancing_algorithm_executed = False
physical_to_logical_map_cpu = None
current_num_layers = model_state.model.num_moe_layers current_num_layers = model_state.model.num_moe_layers
while ( while (
model_state.rebalanced model_state.rebalanced
and model_state.layer_to_transfer < current_num_layers and model_state.layer_to_transfer < current_num_layers
): ):
if ( if not model_state.ep_buffer_ready and model_state.rebalanced:
not model_state.ep_buffer_ready # Polling the lock directly in the async thread avoids
and model_state.rebalanced # the thread switch overhead of asyncio.to_thread.
and model_state.new_physical_to_logical_map is not None # This is typically faster than offloading to a worker thread.
): while not model_state.buffer_lock.acquire(blocking=False):
await asyncio.to_thread(model_state.buffer_lock.acquire) await asyncio.sleep(0)
try: try:
if model_state.layer_to_transfer >= current_num_layers: if model_state.layer_to_transfer >= current_num_layers:
break break
if (
not rebalancing_algorithm_executed
or model_state.new_physical_to_logical_map is None
):
# Move the physical_to_logical_map to CPU
# for rebalancing and transfer_layer.
physical_to_logical_map_cpu = (
model_state.physical_to_logical_map.cpu()
)
run_rebalance_experts(
model_state, state, physical_to_logical_map_cpu
)
rebalancing_algorithm_executed = True
logger.info(
"Async worker computed new indices for model %s",
model_state.model_name,
)
assert model_state.new_physical_to_logical_map is not None
assert physical_to_logical_map_cpu is not None
layer_idx = model_state.layer_to_transfer
old_layer_indices = physical_to_logical_map_cpu[layer_idx]
new_layer_indices = model_state.new_physical_to_logical_map[
layer_idx
]
# Wait for the main thread to finish consuming the buffer # Wait for the main thread to finish consuming the buffer
# before overwriting it # before initiating an EPLB transfer on another layer.
if model_state.buffer_consumed_event is not None: if model_state.buffer_consumed_event is not None:
cuda_stream.wait_event(model_state.buffer_consumed_event) cuda_stream.wait_event(model_state.buffer_consumed_event)
model_state.buffer_consumed_event = None model_state.buffer_consumed_event = None
...@@ -97,13 +169,12 @@ async def transfer_run_periodically( ...@@ -97,13 +169,12 @@ async def transfer_run_periodically(
model_state.is_received_locally, model_state.is_received_locally,
model_state.recv_metadata, model_state.recv_metadata,
) = await transfer_layer( ) = await transfer_layer(
old_global_expert_indices=model_state.physical_to_logical_map, old_layer_indices=old_layer_indices,
new_global_expert_indices=model_state.new_physical_to_logical_map, new_layer_indices=new_layer_indices,
expert_weights=model_state.model.expert_weights, expert_weights=model_state.model.expert_weights[layer_idx],
expert_weights_buffer=model_state.expert_buffer, expert_weights_buffer=model_state.expert_buffer,
ep_group=ep_group, ep_group=eplb_group,
is_profile=is_profile, is_profile=is_profile,
layer=model_state.layer_to_transfer,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
) )
......
...@@ -55,6 +55,35 @@ from .rebalance_execute import ( ...@@ -55,6 +55,35 @@ from .rebalance_execute import (
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class EplbStats:
    """
    Model stats used in EPLB rebalancing algorithm.
    """

    global_expert_load_window: torch.Tensor
    """
    Expert load window.
    Shape: (window_size, num_moe_layers, num_physical_experts)
    """
    num_replicas: int
    """
    Number of physical experts.
    """
    num_groups: int
    """
    Number of expert groups.
    """
    num_nodes: int
    """
    Number of nodes.
    """
    num_gpus: int
    """
    Number of GPUs.
    """
@dataclass @dataclass
class EplbModelState: class EplbModelState:
"""EPLB metrics.""" """EPLB metrics."""
...@@ -156,6 +185,11 @@ class EplbModelState: ...@@ -156,6 +185,11 @@ class EplbModelState:
CUDA event recorded after the main thread finishes consuming the buffer. CUDA event recorded after the main thread finishes consuming the buffer.
The async worker waits on this before writing to the buffer again. The async worker waits on this before writing to the buffer again.
""" """
window_ready_event: torch.cuda.Event | None
"""
CUDA event recorded after all-reduce and clone on the main thread.
The async worker waits on this before accessing global_expert_load_window.
"""
ep_buffer_ready: int ep_buffer_ready: int
""" """
The flag indicates whether the expert buffer is ready for transfer. The flag indicates whether the expert buffer is ready for transfer.
...@@ -173,6 +207,10 @@ class EplbModelState: ...@@ -173,6 +207,10 @@ class EplbModelState:
""" """
Whether the async EPLB needs to poll peers for buffer readiness. Whether the async EPLB needs to poll peers for buffer readiness.
""" """
eplb_stats: EplbStats | None
"""
EPLB stats for the model.
"""
is_unchanged: np.ndarray is_unchanged: np.ndarray
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
...@@ -508,10 +546,12 @@ class EplbState: ...@@ -508,10 +546,12 @@ class EplbState:
buffer_lock=threading.Lock(), buffer_lock=threading.Lock(),
buffer_ready_event=None, buffer_ready_event=None,
buffer_consumed_event=None, buffer_consumed_event=None,
window_ready_event=None,
ep_buffer_ready=0, ep_buffer_ready=0,
layer_to_transfer=0, layer_to_transfer=0,
rebalanced=False, rebalanced=False,
pending_global_ready_check=False, pending_global_ready_check=False,
eplb_stats=None,
is_unchanged=np.array([]), is_unchanged=np.array([]),
is_received_locally=np.array([]), is_received_locally=np.array([]),
recv_metadata=RecvMetadata( recv_metadata=RecvMetadata(
...@@ -642,20 +682,6 @@ class EplbState: ...@@ -642,20 +682,6 @@ class EplbState:
ep_group=ep_group, ep_group=ep_group,
is_profile=is_profile, is_profile=is_profile,
) )
if (
eplb_model_state.layer_to_transfer
>= eplb_model_state.model.num_moe_layers
):
self.post_eplb(eplb_model_state, is_profile)
eplb_model_state.rebalanced = False
eplb_model_state.layer_to_transfer = 0
eplb_model_state.pending_global_ready_check = False
logger.info(
"finish async transfer for model %s rank %d layer %d",
eplb_model_state.model_name,
ep_group.rank(),
eplb_model_state.model.num_moe_layers,
)
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval: if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
if self.is_async and any( if self.is_async and any(
...@@ -802,6 +828,7 @@ class EplbState: ...@@ -802,6 +828,7 @@ class EplbState:
for eplb_model_state, global_expert_load_window in zip( for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows self.model_states.values(), global_expert_load_windows
): ):
if not self.is_async or is_profile:
# Get new expert mappings for the model # Get new expert mappings for the model
( (
new_physical_to_logical_map, new_physical_to_logical_map,
...@@ -816,7 +843,6 @@ class EplbState: ...@@ -816,7 +843,6 @@ class EplbState:
eplb_model_state.physical_to_logical_map, eplb_model_state.physical_to_logical_map,
) )
if not self.is_async or is_profile:
# Update expert weights # Update expert weights
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
eplb_model_state.physical_to_logical_map, eplb_model_state.physical_to_logical_map,
...@@ -873,27 +899,25 @@ class EplbState: ...@@ -873,27 +899,25 @@ class EplbState:
gpu_elapsed, gpu_elapsed,
) )
else: else:
max_slots = eplb_model_state.logical_to_physical_map.shape[-1] eplb_model_state.eplb_stats = EplbStats(
padded_logical = torch.nn.functional.pad( # We copy the tensor to snapshot the global_expert_load_window
new_logical_to_physical_map, # on the main thread so that async worker can access it safely
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])), # while the main thread is running.
value=-1, global_expert_load_window=global_expert_load_window.clone(),
).to(eplb_model_state.logical_to_physical_map.device) num_replicas=num_replicas,
new_replica = new_logical_replica_count.to( num_groups=num_groups,
eplb_model_state.logical_replica_count.device num_nodes=num_nodes,
) num_gpus=num_gpus,
)
# Move map to cpu in advance # Record event after clone to signal async worker
eplb_model_state.new_physical_to_logical_map = ( # that load stats data is ready
new_physical_to_logical_map.cpu() sync_event = torch.cuda.Event()
) sync_event.record()
eplb_model_state.new_logical_to_physical_map = padded_logical eplb_model_state.window_ready_event = sync_event
eplb_model_state.new_logical_replica_count = new_replica
eplb_model_state.rebalanced = True eplb_model_state.rebalanced = True
eplb_model_state.layer_to_transfer = 0 eplb_model_state.layer_to_transfer = 0
eplb_model_state.pending_global_ready_check = True eplb_model_state.pending_global_ready_check = True
# Signal async thread to start transferring layers # Signal async thread to start transferring layers
if self.is_async and (not is_profile): if self.is_async and (not is_profile):
self.rearrange_event.set() self.rearrange_event.set()
...@@ -925,11 +949,13 @@ class EplbState: ...@@ -925,11 +949,13 @@ class EplbState:
target_device = model_state.physical_to_logical_map.device target_device = model_state.physical_to_logical_map.device
new_physical = model_state.new_physical_to_logical_map new_physical = model_state.new_physical_to_logical_map
# If the number of physical experts has changed, then the new map needs to
# be copied synchronously to avoid a race condition with the async worker
if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]: if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
model_state.physical_to_logical_map = new_physical.to(target_device) model_state.physical_to_logical_map = new_physical.to(target_device)
else: else:
model_state.physical_to_logical_map[layer].copy_( model_state.physical_to_logical_map[layer].copy_(
new_physical[layer].to(target_device) new_physical[layer].to(target_device, non_blocking=True)
) )
logical_device = model_state.logical_to_physical_map.device logical_device = model_state.logical_to_physical_map.device
...@@ -1004,11 +1030,9 @@ class EplbState: ...@@ -1004,11 +1030,9 @@ class EplbState:
model_state.layer_to_transfer model_state.layer_to_transfer
] ]
expert_weights_buffer = model_state.expert_buffer expert_weights_buffer = model_state.expert_buffer
new_indices = ( new_indices = model_state.new_physical_to_logical_map[
model_state.new_physical_to_logical_map[model_state.layer_to_transfer] model_state.layer_to_transfer
.cpu() ].numpy()
.numpy()
)
move_from_buffer( move_from_buffer(
expert_weights=expert_weights, expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer, expert_weights_buffers=expert_weights_buffer,
...@@ -1019,7 +1043,7 @@ class EplbState: ...@@ -1019,7 +1043,7 @@ class EplbState:
ep_rank=ep_group.rank(), ep_rank=ep_group.rank(),
) )
# Record event after consuming buffer to signal async thread # Record event after consuming buffer to signal async thread
# that it's safe to overwrite the buffer # that it's safe to overwrite the intermediate buffer
consumed_event = torch.cuda.Event() consumed_event = torch.cuda.Event()
consumed_event.record() consumed_event.record()
model_state.buffer_consumed_event = consumed_event model_state.buffer_consumed_event = consumed_event
...@@ -1034,6 +1058,18 @@ class EplbState: ...@@ -1034,6 +1058,18 @@ class EplbState:
model_state.model_name, model_state.model_name,
transferred_layer, transferred_layer,
) )
if model_state.layer_to_transfer >= model_state.model.num_moe_layers:
self.post_eplb(model_state, is_profile)
model_state.rebalanced = False
model_state.layer_to_transfer = 0
model_state.pending_global_ready_check = False
logger.info(
"finish async transfer for model %s rank %d layer %d",
model_state.model_name,
ep_group.rank(),
model_state.model.num_moe_layers,
)
finally: finally:
try: try:
model_state.buffer_lock.release() model_state.buffer_lock.release()
...@@ -1048,9 +1084,7 @@ class EplbState: ...@@ -1048,9 +1084,7 @@ class EplbState:
assert model_state.new_physical_to_logical_map is not None assert model_state.new_physical_to_logical_map is not None
assert model_state.new_logical_to_physical_map is not None assert model_state.new_logical_to_physical_map is not None
assert model_state.new_logical_replica_count is not None assert model_state.new_logical_replica_count is not None
if not is_profile:
for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
self._update_layer_mapping_from_new(model_state, layer_idx)
model_state.new_physical_to_logical_map = None model_state.new_physical_to_logical_map = None
model_state.new_logical_to_physical_map = None model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None model_state.new_logical_replica_count = None
......
...@@ -434,13 +434,12 @@ def move_from_buffer( ...@@ -434,13 +434,12 @@ def move_from_buffer(
async def transfer_layer( async def transfer_layer(
old_global_expert_indices: torch.Tensor, old_layer_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor, new_layer_indices: torch.Tensor,
expert_weights: Sequence[Sequence[torch.Tensor]], expert_weights: Sequence[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor], expert_weights_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup, ep_group: ProcessGroup,
is_profile: bool = False, is_profile: bool = False,
layer: int = 0,
cuda_stream: torch.cuda.Stream | None = None, cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> MoveToBufferResult: ) -> MoveToBufferResult:
...@@ -451,56 +450,64 @@ async def transfer_layer( ...@@ -451,56 +450,64 @@ async def transfer_layer(
while keys are physical. while keys are physical.
Args: Args:
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts). old_layer_indices: Shape (num_physical_experts,).
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts). new_layer_indices: Shape (num_physical_experts,).
expert_weights: A sequence of shape (num_moe_layers)(weight_count) expert_weights: Iterable of weight tensors for this layer, each with shape
of tensors of shape (num_local_physical_experts, hidden_size_i). (num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection, For example, a linear layer may have up and down projection.
so weight_count = 2. Each weight's hidden size can be different. expert_weights_buffer: Intermediate buffers (one per weight tensor).
ep_group: The device process group for expert parallelism. ep_group: The device process group for expert parallelism.
is_profile (bool): If `True`, do not perform any actual weight copy. is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers. communications to reserve enough memory for the buffers.
cuda_stream: CUDA stream for async copies (can be None for sync mode).
rank_mapping: Optional rank mapping for elastic expert parallelism.
Returns: Returns:
is_unchanged (np.ndarray): (1, num_local_experts), True where expert is_unchanged (np.ndarray): (num_local_experts,), True where expert
is left unchanged. is left unchanged.
is_received_locally (np.ndarray): (1, num_local_experts), True where expert is_received_locally (np.ndarray): (num_local_experts,), True where expert
can be received locally. can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers. RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
ep_size = ep_group.size() ep_size = ep_group.size()
if rank_mapping is not None: if rank_mapping is not None:
# Add a layer dimension for compatibility with mapping functions
old_layer_indices_2d = old_layer_indices.unsqueeze(0)
new_layer_indices_2d = new_layer_indices.unsqueeze(0)
if len(rank_mapping) == ep_group.size(): if len(rank_mapping) == ep_group.size():
# scale down # scale down
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping( new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices, new_layer_indices_2d,
rank_mapping, rank_mapping,
) )
else: else:
# scale up # scale up
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping( old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices, old_layer_indices_2d,
rank_mapping, rank_mapping,
ep_group.size(), ep_group.size(),
) )
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1] # Remove the layer dimension
num_moe_layers, num_physical_experts = old_global_expert_indices.shape old_layer_indices = old_layer_indices_2d.squeeze(0)
assert len(expert_weights) == num_moe_layers new_layer_indices = new_layer_indices_2d.squeeze(0)
assert old_layer_indices.shape == new_layer_indices.shape
num_physical_experts = old_layer_indices.shape[0]
assert len(expert_weights[0]) >= 1 assert len(expert_weights[0]) >= 1
num_local_physical_experts = expert_weights[0][0].shape[0] num_local_physical_experts = expert_weights[0].shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
assert num_physical_experts == ep_size * num_local_physical_experts assert num_physical_experts == ep_size * num_local_physical_experts
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy() old_layer_indices_np = old_layer_indices.cpu().numpy()
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy() new_layer_indices_np = new_layer_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer( is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts, num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices_np[layer], old_indices=old_layer_indices_np,
new_indices=new_global_expert_indices_np[layer], new_indices=new_layer_indices_np,
expert_weights=expert_weights[layer], expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer, expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
ep_group=ep_group, ep_group=ep_group,
......
...@@ -1143,6 +1143,18 @@ def get_ep_group() -> GroupCoordinator: ...@@ -1143,6 +1143,18 @@ def get_ep_group() -> GroupCoordinator:
return _EP return _EP
_EPLB: GroupCoordinator | None = None
def get_eplb_group() -> GroupCoordinator:
    """Return the module-level EPLB group coordinator `_EPLB`.

    Raises:
        AssertionError: If `_EPLB` is still `None`, i.e. the EPLB group has
            not been created.
    """
    assert _EPLB is not None, (
        "EPLB group is not initialized. "
        "EPLB group is only created for MoE models when EPLB is enabled. "
        "Ensure parallel_config.enable_eplb is True."
    )
    return _EPLB
_PCP: GroupCoordinator | None = None _PCP: GroupCoordinator | None = None
...@@ -1440,12 +1452,29 @@ def initialize_model_parallel( ...@@ -1440,12 +1452,29 @@ def initialize_model_parallel(
_EP = init_model_parallel_group( _EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep" group_ranks, get_world_group().local_rank, backend, group_name="ep"
) )
# Create EPLB group with the same ranks as EP if EPLB is enabled.
# This is a separate process group to isolate EPLB communications
# from MoE forward pass collectives and prevent deadlocks when
# torch.distributed is used concurrently by model execution and by EPLB.
global _EPLB
assert _EPLB is None, "EPLB group is already initialized"
if (
config is not None
and config.parallel_config is not None
and config.parallel_config.enable_eplb
):
# Reuse the same group_ranks from EP
_EPLB = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
)
# If no EP group needed, _EP remains None # If no EP group needed, _EP remains None
# If no EPLB group needed, _EPLB remains None
logger.info_once( logger.info_once(
"rank %s in world size %s is assigned as " "rank %s in world size %s is assigned as "
"DP rank %s, PP rank %s, PCP rank %s, " "DP rank %s, PP rank %s, PCP rank %s, "
"TP rank %s, EP rank %s", "TP rank %s, EP rank %s, EPLB rank %s",
rank, rank,
world_size, world_size,
_DP.rank_in_group, _DP.rank_in_group,
...@@ -1453,6 +1482,7 @@ def initialize_model_parallel( ...@@ -1453,6 +1482,7 @@ def initialize_model_parallel(
_PCP.rank_in_group, _PCP.rank_in_group,
_TP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group if _EP is not None else "N/A", _EP.rank_in_group if _EP is not None else "N/A",
_EPLB.rank_in_group if _EPLB is not None else "N/A",
) )
...@@ -1514,6 +1544,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module): ...@@ -1514,6 +1544,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
_DP.prepare_communication_buffer_for_model(model) _DP.prepare_communication_buffer_for_model(model)
if _EP is not None: if _EP is not None:
_EP.prepare_communication_buffer_for_model(model) _EP.prepare_communication_buffer_for_model(model)
if _EPLB is not None:
_EPLB.prepare_communication_buffer_for_model(model)
def model_parallel_is_initialized(): def model_parallel_is_initialized():
...@@ -1608,6 +1640,11 @@ def destroy_model_parallel(): ...@@ -1608,6 +1640,11 @@ def destroy_model_parallel():
_EP.destroy() _EP.destroy()
_EP = None _EP = None
global _EPLB
if _EPLB:
_EPLB.destroy()
_EPLB = None
def destroy_distributed_environment(): def destroy_distributed_environment():
global _WORLD, _NODE_COUNT global _WORLD, _NODE_COUNT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment