Unverified Commit 3173441b authored by Sage Moore's avatar Sage Moore Committed by GitHub
Browse files

[EPLB] Consolidate is_unchanged/is_received_locally into TransferMetadata (#37341)


Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
parent 8b1f3beb
...@@ -361,7 +361,7 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -361,7 +361,7 @@ def _test_async_transfer_layer_without_mtp_worker(
communicator.set_stream(cuda_stream) communicator.set_stream(cuda_stream)
for layer_idx in range(num_layers): for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run( transfer_metadata = asyncio.run(
transfer_layer( transfer_layer(
old_layer_indices=old_indices_cpu[layer_idx], old_layer_indices=old_indices_cpu[layer_idx],
new_layer_indices=new_indices_cpu[layer_idx], new_layer_indices=new_indices_cpu[layer_idx],
...@@ -376,9 +376,7 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -376,9 +376,7 @@ def _test_async_transfer_layer_without_mtp_worker(
move_from_buffer( move_from_buffer(
expert_weights=expert_weights[layer_idx], expert_weights=expert_weights[layer_idx],
expert_weights_buffers=expert_buffer, expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged, transfer_metadata=transfer_metadata,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
new_indices=new_indices_cpu[layer_idx].numpy(), new_indices=new_indices_cpu[layer_idx].numpy(),
ep_rank=ep_rank, ep_rank=ep_rank,
) )
......
...@@ -118,11 +118,7 @@ async def transfer_run_periodically( ...@@ -118,11 +118,7 @@ async def transfer_run_periodically(
# model_state.expert_buffer, which will be consumed by the main thread in # model_state.expert_buffer, which will be consumed by the main thread in
# move_to_workspace # move_to_workspace
while model_state.rebalanced and layer_idx < num_layers: while model_state.rebalanced and layer_idx < num_layers:
( transfer_metadata = await transfer_layer(
is_unchanged,
is_received_locally,
recv_metadata,
) = await transfer_layer(
old_layer_indices=physical_to_logical_map_cpu[layer_idx], old_layer_indices=physical_to_logical_map_cpu[layer_idx],
new_layer_indices=new_physical_to_logical_map[layer_idx], new_layer_indices=new_physical_to_logical_map[layer_idx],
expert_weights=model_state.model.expert_weights[layer_idx], expert_weights=model_state.model.expert_weights[layer_idx],
...@@ -145,9 +141,7 @@ async def transfer_run_periodically( ...@@ -145,9 +141,7 @@ async def transfer_run_periodically(
model_state.pending_result = AsyncEplbLayerResult( model_state.pending_result = AsyncEplbLayerResult(
layer_idx=layer_idx, layer_idx=layer_idx,
new_physical_to_logical_map=new_physical_to_logical_map[layer_idx], new_physical_to_logical_map=new_physical_to_logical_map[layer_idx],
is_unchanged=is_unchanged, transfer_metadata=transfer_metadata,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
consumed_event=consumed_event, consumed_event=consumed_event,
) )
......
...@@ -1147,9 +1147,7 @@ def _move_to_workspace( ...@@ -1147,9 +1147,7 @@ def _move_to_workspace(
move_from_buffer( move_from_buffer(
expert_weights=model_state.model.expert_weights[result.layer_idx], expert_weights=model_state.model.expert_weights[result.layer_idx],
expert_weights_buffers=model_state.expert_buffer, expert_weights_buffers=model_state.expert_buffer,
is_unchanged=result.is_unchanged, transfer_metadata=result.transfer_metadata,
is_received_locally=result.is_received_locally,
recv_metadata=result.recv_metadata,
new_indices=result.new_physical_to_logical_map.numpy(), new_indices=result.new_physical_to_logical_map.numpy(),
ep_rank=ep_rank, ep_rank=ep_rank,
) )
......
...@@ -21,9 +21,13 @@ logger = init_logger(__name__) ...@@ -21,9 +21,13 @@ logger = init_logger(__name__)
@dataclass @dataclass
class RecvMetadata: class TransferMetadata:
"""Metadata describing remote receives during EPLB rebalancing.""" """Metadata describing a completed EPLB buffer transfer."""
is_unchanged: np.ndarray
"""Mask of (num_local_experts,) indicating experts unchanged after rebalance."""
is_received_locally: np.ndarray
"""Mask of (num_local_experts,) indicating experts received from local data."""
recv_primary_mask: np.ndarray recv_primary_mask: np.ndarray
"""Mask of (num_local_experts,) indicating primary experts received.""" """Mask of (num_local_experts,) indicating primary experts received."""
recv_count: int recv_count: int
...@@ -34,10 +38,6 @@ class RecvMetadata: ...@@ -34,10 +38,6 @@ class RecvMetadata:
"""Target expert indices (num_local_experts,) in local tensors to send.""" """Target expert indices (num_local_experts,) in local tensors to send."""
# Type alias for the result of move_to_buffer or transfer_layer
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
@dataclass @dataclass
class AsyncEplbLayerResult: class AsyncEplbLayerResult:
""" """
...@@ -51,11 +51,7 @@ class AsyncEplbLayerResult: ...@@ -51,11 +51,7 @@ class AsyncEplbLayerResult:
New physical→logical mapping for layers_idx, on CPU. New physical→logical mapping for layers_idx, on CPU.
Shape: (num_physical_experts) Shape: (num_physical_experts)
""" """
is_unchanged: np.ndarray transfer_metadata: TransferMetadata
"""Per-physical-expert flag: weight was not moved during transfer."""
is_received_locally: np.ndarray
"""Per-physical-expert flag: weight was received on this rank."""
recv_metadata: RecvMetadata
"""Metadata describing what was received during transfer_layer.""" """Metadata describing what was received during transfer_layer."""
consumed_event: CpuGpuEvent consumed_event: CpuGpuEvent
""" """
...@@ -182,7 +178,7 @@ def move_to_buffer( ...@@ -182,7 +178,7 @@ def move_to_buffer(
cuda_stream: torch.cuda.Stream | None, cuda_stream: torch.cuda.Stream | None,
ep_rank: int, ep_rank: int,
communicator: EplbCommunicator, communicator: EplbCommunicator,
) -> MoveToBufferResult: ) -> TransferMetadata:
""" """
Rearranges expert weights during EPLB rebalancing. Rearranges expert weights during EPLB rebalancing.
...@@ -199,11 +195,7 @@ def move_to_buffer( ...@@ -199,11 +195,7 @@ def move_to_buffer(
communicator: EplbCommunicator instance for P2P communication. communicator: EplbCommunicator instance for P2P communication.
Returns: Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row TransferMetadata: Metadata needed for completing remote weight transfers.
is unchanged after rebalance.
is_received_locally (np.ndarray): (num_local_experts,), True where a row
can be updated from local data.
RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
assert old_indices.shape == new_indices.shape assert old_indices.shape == new_indices.shape
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_) recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
...@@ -339,24 +331,20 @@ def move_to_buffer( ...@@ -339,24 +331,20 @@ def move_to_buffer(
# 4. Execute the P2P operations. The real communication happens here. # 4. Execute the P2P operations. The real communication happens here.
communicator.execute() communicator.execute()
# wait for the communication to finish # wait for the communication to finish
return ( return TransferMetadata(
is_unchanged, is_unchanged=is_unchanged,
is_received_locally, is_received_locally=is_received_locally,
RecvMetadata( recv_primary_mask=recv_primary_mask,
recv_primary_mask=recv_primary_mask, recv_count=recv_count,
recv_count=recv_count, recv_expert_ids=recv_expert_ids,
recv_expert_ids=recv_expert_ids, recv_dst_rows=recv_dst_rows,
recv_dst_rows=recv_dst_rows,
),
) )
def move_from_buffer( def move_from_buffer(
expert_weights: Sequence[torch.Tensor], expert_weights: Sequence[torch.Tensor],
expert_weights_buffers: list[torch.Tensor], expert_weights_buffers: list[torch.Tensor],
is_unchanged: np.ndarray, transfer_metadata: TransferMetadata,
is_received_locally: np.ndarray,
recv_metadata: RecvMetadata,
new_indices: np.ndarray, new_indices: np.ndarray,
ep_rank: int, ep_rank: int,
) -> None: ) -> None:
...@@ -368,17 +356,17 @@ def move_from_buffer( ...@@ -368,17 +356,17 @@ def move_from_buffer(
expert_weights: List of the actual MoE layer weights used in the execution. expert_weights: List of the actual MoE layer weights used in the execution.
expert_weights_buffers: Intermediate buffers containing the experts weights expert_weights_buffers: Intermediate buffers containing the experts weights
after the transfer is completed. after the transfer is completed.
is_unchanged: (num_local_experts,), True where an expert row is unchanged. transfer_metadata: TransferMetadata containing transfer metadata.
is_received_locally: (num_local_experts,), True where a row is updated locally.
recv_metadata: RecvMetadata containing remote receive metadata.
new_indices: (num_experts_total,) mapping from local rows to desired new_indices: (num_experts_total,) mapping from local rows to desired
(possibly global) expert id, after rebalance. (possibly global) expert id, after rebalance.
ep_rank: Rank of the process in the expert parallel group. ep_rank: Rank of the process in the expert parallel group.
""" """
recv_primary_mask = recv_metadata.recv_primary_mask is_unchanged = transfer_metadata.is_unchanged
recv_count = recv_metadata.recv_count is_received_locally = transfer_metadata.is_received_locally
recv_expert_ids = recv_metadata.recv_expert_ids recv_primary_mask = transfer_metadata.recv_primary_mask
recv_dst_rows = recv_metadata.recv_dst_rows recv_count = transfer_metadata.recv_count
recv_expert_ids = transfer_metadata.recv_expert_ids
recv_dst_rows = transfer_metadata.recv_dst_rows
num_local_experts = is_unchanged.shape[0] num_local_experts = is_unchanged.shape[0]
# Mask for rows to copy back from buffers: # Mask for rows to copy back from buffers:
...@@ -440,7 +428,7 @@ async def transfer_layer( ...@@ -440,7 +428,7 @@ async def transfer_layer(
is_profile: bool = False, is_profile: bool = False,
cuda_stream: torch.cuda.Stream | None = None, cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> MoveToBufferResult: ) -> TransferMetadata:
""" """
Rearranges the expert weights in place according to the new expert indices. Rearranges the expert weights in place according to the new expert indices.
...@@ -463,11 +451,8 @@ async def transfer_layer( ...@@ -463,11 +451,8 @@ async def transfer_layer(
rank_mapping: Optional rank mapping for elastic expert parallelism. rank_mapping: Optional rank mapping for elastic expert parallelism.
Returns: Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where expert TransferMetadata: Metadata needed for completing remote weight transfers,
is left unchanged. including is_unchanged and is_received_locally masks.
is_received_locally (np.ndarray): (num_local_experts,), True where expert
can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
ep_size = ep_group.size() ep_size = ep_group.size()
if rank_mapping is not None: if rank_mapping is not None:
...@@ -502,7 +487,7 @@ async def transfer_layer( ...@@ -502,7 +487,7 @@ async def transfer_layer(
old_layer_indices_np = old_layer_indices.cpu().numpy() old_layer_indices_np = old_layer_indices.cpu().numpy()
new_layer_indices_np = new_layer_indices.cpu().numpy() new_layer_indices_np = new_layer_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer( return move_to_buffer(
num_local_experts=num_local_physical_experts, num_local_experts=num_local_physical_experts,
old_indices=old_layer_indices_np, old_indices=old_layer_indices_np,
new_indices=new_layer_indices_np, new_indices=new_layer_indices_np,
...@@ -512,7 +497,6 @@ async def transfer_layer( ...@@ -512,7 +497,6 @@ async def transfer_layer(
ep_rank=ep_group.rank(), ep_rank=ep_group.rank(),
communicator=communicator, communicator=communicator,
) )
return is_unchanged, is_received_locally, recv_metadata
def rearrange_expert_weights_inplace( def rearrange_expert_weights_inplace(
...@@ -605,7 +589,7 @@ def rearrange_expert_weights_inplace( ...@@ -605,7 +589,7 @@ def rearrange_expert_weights_inplace(
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy() new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
for layer_idx in range(num_moe_layers): for layer_idx in range(num_moe_layers):
is_unchanged, is_received_locally, recv_metadata = move_to_buffer( transfer_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts, num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices_cpu[layer_idx], old_indices=old_global_expert_indices_cpu[layer_idx],
new_indices=new_global_expert_indices_cpu[layer_idx], new_indices=new_global_expert_indices_cpu[layer_idx],
...@@ -619,9 +603,7 @@ def rearrange_expert_weights_inplace( ...@@ -619,9 +603,7 @@ def rearrange_expert_weights_inplace(
move_from_buffer( move_from_buffer(
expert_weights=expert_weights[layer_idx], expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer, expert_weights_buffers=weights_buffer,
is_unchanged=is_unchanged, transfer_metadata=transfer_metadata,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
new_indices=new_global_expert_indices_cpu[layer_idx], new_indices=new_global_expert_indices_cpu[layer_idx],
ep_rank=ep_rank, ep_rank=ep_rank,
) )
...@@ -715,4 +697,4 @@ def _map_new_expert_indices_with_rank_mapping( ...@@ -715,4 +697,4 @@ def _map_new_expert_indices_with_rank_mapping(
return mapped_expert_indices return mapped_expert_indices
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"] __all__ = ["transfer_layer", "move_from_buffer", "TransferMetadata"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment