Unverified Commit 1b8cf77b authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

[Fix] incorrect assert in EPLB (#7575)

parent bb9b608c
......@@ -479,10 +479,6 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
topk_ids = topk_ids.flatten()
mask = topk_ids != -1
assert self._data[layer_idx, :].shape == topk_ids.shape, (
"Shape mismatch between data and topk_ids."
"Selecting expert is not supported for multiple token prediction at the moment."
)
self._data[layer_idx, :].scatter_add_(
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment