Commit 1c555604 authored by Jiezhong Qiu

check moe input/output batch sizes are the same

parent 27c8c2f3
@@ -179,6 +179,14 @@ class FMoE(nn.Module):
         according to the gate. The score of the selected gate given by the
         expert is multiplied to the experts' output tensors as a weight.
         """
+        moe_inp_batch_size = tree.flatten(
+            tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
+        )
+        assert all(
+            [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
+        ), "MoE inputs must have the same batch size"
+
         if self.world_size > 1:
             def ensure_comm_func(tensor):
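For readers unfamiliar with the `tree` calls in the added check: FastMoE uses the dm-tree package (imported as `tree`), so `moe_inp` may be an arbitrarily nested structure of tensors. The new assertion maps every leaf tensor to its leading dimension, flattens the result to a list, and requires all entries to match. A minimal standalone sketch of the same pattern follows; the tensor names and sizes are made up purely for illustration.

import torch
import tree  # dm-tree, the package FastMoE imports as `tree`

# A nested MoE input: any structure of tensors is allowed, as long as
# every leaf shares the same leading (batch) dimension.
moe_inp = {"tokens": torch.randn(8, 512), "extras": (torch.randn(8, 64),)}

# Collect the leading dimension of every leaf tensor, then flatten to a list.
moe_inp_batch_size = tree.flatten(
    tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
)
assert all(
    batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size
), "MoE inputs must have the same batch size"

print(moe_inp_batch_size)  # [8, 8], so the check passes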
@@ -199,14 +207,14 @@ class FMoE(nn.Module):
         if self.gate_hook is not None:
             self.gate_hook(gate_top_k_idx, gate_score, None)
 
-        # TODO: to fix
-        def delete_mask_func(tensor):
-            # to: (BxL') x d_model
-            tensor = tensor[mask == 0, :]
-            return tensor
-
         # delete masked tensors
         if self.mask is not None and self.mask_dict is not None:
+            # TODO: to fix
+            def delete_mask_func(tensor):
+                # to: (BxL') x d_model
+                tensor = tensor[mask == 0, :]
+                return tensor
+
             mask = self.mask.view(-1)
             moe_inp = tree.map_structure(delete_mask_func, moe_inp)
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]
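This hunk only moves `delete_mask_func` inside the masking branch; its behaviour is unchanged: it keeps the rows whose mask entry is 0, shrinking a (B x L) x d_model tensor to (B x L') x d_model before expert dispatch. A small standalone sketch of that row selection follows; the shapes and mask values are illustrative only.

import torch

d_model = 4
# Flattened (B x L) token representations with a per-token mask;
# 0 marks tokens routed to the experts, non-zero marks tokens to drop.
flat_inp = torch.arange(6 * d_model, dtype=torch.float32).view(6, d_model)
mask = torch.tensor([0, 1, 0, 0, 1, 0])

# Same selection as delete_mask_func: keep rows where mask == 0,
# turning (B x L) x d_model into (B x L') x d_model.
kept = flat_inp[mask == 0, :]
print(kept.shape)  # torch.Size([4, 4])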
@@ -263,4 +271,11 @@ class FMoE(nn.Module):
                 )
             moe_outp = tree.map_structure(all_gather_func, moe_outp)
 
+        moe_outp_batch_size = tree.flatten(
+            tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
+        )
+        assert all(
+            [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size]
+        ), "MoE outputs must have the same batch size"
+
         return moe_outp