Commit 4b650671 authored by Rick Ho
Browse files

fix bmm out shape

parent dc3db673
......@@ -131,7 +131,7 @@ class FMoETransformerMLP(nn.Module):
core_out = x.view(-1, self.top_k, self.d_model)
# to: (BxL) x 1 x d_model
core_out = torch.bmm(gate_score, core_out)
- output = core_out + residual
+ output = core_out.reshape(residual.shape) + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment