Commit 1c555604 authored by Jiezhong Qiu
Browse files

Check that MoE input and output batch sizes are the same

parent 27c8c2f3
...@@ -179,6 +179,14 @@ class FMoE(nn.Module): ...@@ -179,6 +179,14 @@ class FMoE(nn.Module):
according to the gate. The score of the selected gate given by the according to the gate. The score of the selected gate given by the
expert is multiplied to the experts' output tensors as a weight. expert is multiplied to the experts' output tensors as a weight.
""" """
moe_inp_batch_size = tree.flatten(
tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
)
assert all(
[batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
), "MoE inputs must have the same batch size"
if self.world_size > 1: if self.world_size > 1:
def ensure_comm_func(tensor): def ensure_comm_func(tensor):
...@@ -199,14 +207,14 @@ class FMoE(nn.Module): ...@@ -199,14 +207,14 @@ class FMoE(nn.Module):
if self.gate_hook is not None: if self.gate_hook is not None:
self.gate_hook(gate_top_k_idx, gate_score, None) self.gate_hook(gate_top_k_idx, gate_score, None)
# delete masked tensors
if self.mask is not None and self.mask_dict is not None:
# TODO: to fix # TODO: to fix
def delete_mask_func(tensor): def delete_mask_func(tensor):
# to: (BxL') x d_model # to: (BxL') x d_model
tensor = tensor[mask == 0, :] tensor = tensor[mask == 0, :]
return tensor return tensor
# delete masked tensors
if self.mask is not None and self.mask_dict is not None:
mask = self.mask.view(-1) mask = self.mask.view(-1)
moe_inp = tree.map_structure(delete_mask_func, moe_inp) moe_inp = tree.map_structure(delete_mask_func, moe_inp)
gate_top_k_idx = gate_top_k_idx[mask == 0, :] gate_top_k_idx = gate_top_k_idx[mask == 0, :]
...@@ -263,4 +271,11 @@ class FMoE(nn.Module): ...@@ -263,4 +271,11 @@ class FMoE(nn.Module):
) )
moe_outp = tree.map_structure(all_gather_func, moe_outp) moe_outp = tree.map_structure(all_gather_func, moe_outp)
moe_outp_batch_size = tree.flatten(
tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
)
assert all(
[batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size]
), "MoE outputs must have the same batch size"
return moe_outp return moe_outp
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment