Unverified Commit c499591a authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Add canary for EPLB rebalancing (#6895)

parent e1ce44cd
......@@ -61,7 +61,62 @@ class ExpertLocationUpdater:
)
def _update_expert_weights(
def _update_expert_weights(**kwargs):
if get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_CANARY"):
return _update_expert_weights_with_canary(**kwargs)
else:
return _update_expert_weights_raw(**kwargs)
# can add watchdog as well
def _update_expert_weights_with_canary(
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
old_expert_location_metadata: ExpertLocationMetadata,
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
nnodes: int,
rank: int,
):
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
def _get_canary_value(meta: ExpertLocationMetadata, layer_id: int):
return meta.physical_to_logical_map_cpu[
layer_id,
num_local_physical_experts * rank : num_local_physical_experts * (rank + 1),
]
routed_experts_weights_of_layer = {
k: [x for x in v] for k, v in routed_experts_weights_of_layer.items()
}
for layer_id in update_layer_ids:
canary_tensor = (
_get_canary_value(old_expert_location_metadata, layer_id)
.clone()
.to(device=global_server_args_dict["device"], non_blocking=True)
)
routed_experts_weights_of_layer[layer_id].append(canary_tensor)
_update_expert_weights_raw(
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
old_expert_location_metadata=old_expert_location_metadata,
new_expert_location_metadata=new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=nnodes,
rank=rank,
)
for layer_id in update_layer_ids:
# can optimize speed if needed
expect_value = _get_canary_value(new_expert_location_metadata, layer_id)
actual_value = routed_experts_weights_of_layer[layer_id][-1].cpu()
assert torch.all(expect_value == actual_value), (
f"{expect_value=} {actual_value=} {layer_id=} "
f"{old_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} "
f"{new_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} "
)
def _update_expert_weights_raw(
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
old_expert_location_metadata: ExpertLocationMetadata,
new_expert_location_metadata: ExpertLocationMetadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment