Unverified Commit 3173441b authored by Sage Moore's avatar Sage Moore Committed by GitHub
Browse files

[EPLB] Consolidate is_unchanged/is_received_locally into TransferMetadata (#37341)


Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
parent 8b1f3beb
......@@ -361,7 +361,7 @@ def _test_async_transfer_layer_without_mtp_worker(
communicator.set_stream(cuda_stream)
for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_metadata = asyncio.run(
transfer_layer(
old_layer_indices=old_indices_cpu[layer_idx],
new_layer_indices=new_indices_cpu[layer_idx],
......@@ -376,9 +376,7 @@ def _test_async_transfer_layer_without_mtp_worker(
move_from_buffer(
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
transfer_metadata=transfer_metadata,
new_indices=new_indices_cpu[layer_idx].numpy(),
ep_rank=ep_rank,
)
......
......@@ -118,11 +118,7 @@ async def transfer_run_periodically(
# model_state.expert_buffer, which will be consumed by the main thread in
# move_to_workspace
while model_state.rebalanced and layer_idx < num_layers:
(
is_unchanged,
is_received_locally,
recv_metadata,
) = await transfer_layer(
transfer_metadata = await transfer_layer(
old_layer_indices=physical_to_logical_map_cpu[layer_idx],
new_layer_indices=new_physical_to_logical_map[layer_idx],
expert_weights=model_state.model.expert_weights[layer_idx],
......@@ -145,9 +141,7 @@ async def transfer_run_periodically(
model_state.pending_result = AsyncEplbLayerResult(
layer_idx=layer_idx,
new_physical_to_logical_map=new_physical_to_logical_map[layer_idx],
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
transfer_metadata=transfer_metadata,
consumed_event=consumed_event,
)
......
......@@ -1147,9 +1147,7 @@ def _move_to_workspace(
move_from_buffer(
expert_weights=model_state.model.expert_weights[result.layer_idx],
expert_weights_buffers=model_state.expert_buffer,
is_unchanged=result.is_unchanged,
is_received_locally=result.is_received_locally,
recv_metadata=result.recv_metadata,
transfer_metadata=result.transfer_metadata,
new_indices=result.new_physical_to_logical_map.numpy(),
ep_rank=ep_rank,
)
......
......@@ -21,9 +21,13 @@ logger = init_logger(__name__)
@dataclass
class RecvMetadata:
"""Metadata describing remote receives during EPLB rebalancing."""
class TransferMetadata:
"""Metadata describing a completed EPLB buffer transfer."""
is_unchanged: np.ndarray
"""Mask of (num_local_experts,) indicating experts unchanged after rebalance."""
is_received_locally: np.ndarray
"""Mask of (num_local_experts,) indicating experts received from local data."""
recv_primary_mask: np.ndarray
"""Mask of (num_local_experts,) indicating primary experts received."""
recv_count: int
......@@ -34,10 +38,6 @@ class RecvMetadata:
"""Target expert indices (num_local_experts,) in local tensors to send."""
# Type alias for the result of move_to_buffer or transfer_layer
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
@dataclass
class AsyncEplbLayerResult:
"""
......@@ -51,11 +51,7 @@ class AsyncEplbLayerResult:
New physical→logical mapping for layers_idx, on CPU.
Shape: (num_physical_experts)
"""
is_unchanged: np.ndarray
"""Per-physical-expert flag: weight was not moved during transfer."""
is_received_locally: np.ndarray
"""Per-physical-expert flag: weight was received on this rank."""
recv_metadata: RecvMetadata
transfer_metadata: TransferMetadata
"""Metadata describing what was received during transfer_layer."""
consumed_event: CpuGpuEvent
"""
......@@ -182,7 +178,7 @@ def move_to_buffer(
cuda_stream: torch.cuda.Stream | None,
ep_rank: int,
communicator: EplbCommunicator,
) -> MoveToBufferResult:
) -> TransferMetadata:
"""
Rearranges expert weights during EPLB rebalancing.
......@@ -199,11 +195,7 @@ def move_to_buffer(
communicator: EplbCommunicator instance for P2P communication.
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
is unchanged after rebalance.
is_received_locally (np.ndarray): (num_local_experts,), True where a row
can be updated from local data.
RecvMetadata: Metadata needed for completing remote weight transfers.
TransferMetadata: Metadata needed for completing remote weight transfers.
"""
assert old_indices.shape == new_indices.shape
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
......@@ -339,24 +331,20 @@ def move_to_buffer(
# 4. Execute the P2P operations. The real communication happens here.
communicator.execute()
# wait for the communication to finish
return (
is_unchanged,
is_received_locally,
RecvMetadata(
recv_primary_mask=recv_primary_mask,
recv_count=recv_count,
recv_expert_ids=recv_expert_ids,
recv_dst_rows=recv_dst_rows,
),
return TransferMetadata(
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_primary_mask=recv_primary_mask,
recv_count=recv_count,
recv_expert_ids=recv_expert_ids,
recv_dst_rows=recv_dst_rows,
)
def move_from_buffer(
expert_weights: Sequence[torch.Tensor],
expert_weights_buffers: list[torch.Tensor],
is_unchanged: np.ndarray,
is_received_locally: np.ndarray,
recv_metadata: RecvMetadata,
transfer_metadata: TransferMetadata,
new_indices: np.ndarray,
ep_rank: int,
) -> None:
......@@ -368,17 +356,17 @@ def move_from_buffer(
expert_weights: List of the actual MoE layer weights used in the execution.
expert_weights_buffers: Intermediate buffers containing the experts weights
after the transfer is completed.
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
is_received_locally: (num_local_experts,), True where a row is updated locally.
recv_metadata: RecvMetadata containing remote receive metadata.
transfer_metadata: TransferMetadata containing transfer metadata.
new_indices: (num_experts_total,) mapping from local rows to desired
(possibly global) expert id, after rebalance.
ep_rank: Rank of the process in the expert parallel group.
"""
recv_primary_mask = recv_metadata.recv_primary_mask
recv_count = recv_metadata.recv_count
recv_expert_ids = recv_metadata.recv_expert_ids
recv_dst_rows = recv_metadata.recv_dst_rows
is_unchanged = transfer_metadata.is_unchanged
is_received_locally = transfer_metadata.is_received_locally
recv_primary_mask = transfer_metadata.recv_primary_mask
recv_count = transfer_metadata.recv_count
recv_expert_ids = transfer_metadata.recv_expert_ids
recv_dst_rows = transfer_metadata.recv_dst_rows
num_local_experts = is_unchanged.shape[0]
# Mask for rows to copy back from buffers:
......@@ -440,7 +428,7 @@ async def transfer_layer(
is_profile: bool = False,
cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None,
) -> MoveToBufferResult:
) -> TransferMetadata:
"""
Rearranges the expert weights in place according to the new expert indices.
......@@ -463,11 +451,8 @@ async def transfer_layer(
rank_mapping: Optional rank mapping for elastic expert parallelism.
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where expert
is left unchanged.
is_received_locally (np.ndarray): (num_local_experts,), True where expert
can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers.
TransferMetadata: Metadata needed for completing remote weight transfers,
including is_unchanged and is_received_locally masks.
"""
ep_size = ep_group.size()
if rank_mapping is not None:
......@@ -502,7 +487,7 @@ async def transfer_layer(
old_layer_indices_np = old_layer_indices.cpu().numpy()
new_layer_indices_np = new_layer_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
return move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_layer_indices_np,
new_indices=new_layer_indices_np,
......@@ -512,7 +497,6 @@ async def transfer_layer(
ep_rank=ep_group.rank(),
communicator=communicator,
)
return is_unchanged, is_received_locally, recv_metadata
def rearrange_expert_weights_inplace(
......@@ -605,7 +589,7 @@ def rearrange_expert_weights_inplace(
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
for layer_idx in range(num_moe_layers):
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
transfer_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices_cpu[layer_idx],
new_indices=new_global_expert_indices_cpu[layer_idx],
......@@ -619,9 +603,7 @@ def rearrange_expert_weights_inplace(
move_from_buffer(
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
transfer_metadata=transfer_metadata,
new_indices=new_global_expert_indices_cpu[layer_idx],
ep_rank=ep_rank,
)
......@@ -715,4 +697,4 @@ def _map_new_expert_indices_with_rank_mapping(
return mapped_expert_indices
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"]
__all__ = ["transfer_layer", "move_from_buffer", "TransferMetadata"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment