Commit 9e67148c authored by Rick Ho

optional layer-norm in transformer mlp module

parent b97483a4
@@ -47,6 +47,7 @@ class FMoETransformerMLP(FMoE):
         activation=torch.nn.functional.gelu,
         gate=NaiveGate,
         top_k=2,
+        do_lnorm=False,
         pre_lnorm=False,
         expert_dp_comm='none'
     ):
@@ -55,7 +56,11 @@ class FMoETransformerMLP(FMoE):
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
-        self.pre_lnorm = pre_lnorm
-        self.layer_norm = nn.LayerNorm(d_model)
+        if do_lnorm:
+            self.layer_norm = nn.LayerNorm(d_model)
+            self.pre_lnorm = pre_lnorm
+        else:
+            self.pre_lnorm = None
         self.mark_parallel_comm(expert_dp_comm)

     def forward(self, inp: torch.Tensor):
@@ -65,9 +70,9 @@
         '''
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
-        if self.pre_lnorm:
+        if self.pre_lnorm is not None and self.pre_lnorm:
             inp = self.layer_norm(inp)
         output = super().forward(inp) + inp
-        if not self.pre_lnorm:
+        if self.pre_lnorm is not None and not self.pre_lnorm:
             output = self.layer_norm(output)
         return output.reshape(original_shape)
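
For context, a minimal usage sketch of the new flag. The leading constructor arguments (num_expert, d_model, d_hidden) and the import path are assumptions based on typical fastmoe usage, not part of this diff; only do_lnorm and pre_lnorm come from this commit.

    import torch
    from fmoe.transformer import FMoETransformerMLP  # import path assumed

    # do_lnorm=True creates self.layer_norm; pre_lnorm chooses whether the
    # LayerNorm is applied to the input (True) or to the residual output (False).
    mlp = FMoETransformerMLP(
        num_expert=4,       # assumed value for illustration
        d_model=512,        # assumed value for illustration
        d_hidden=2048,      # assumed value for illustration
        do_lnorm=True,
        pre_lnorm=True,
    )

    x = torch.randn(8, 16, 512)   # (batch, seq_len, d_model)
    y = mlp(x)                    # output has the same shape as x

    # With do_lnorm=False (the default) no nn.LayerNorm is created,
    # self.pre_lnorm is set to None, and forward() skips both norm branches.
    no_norm = FMoETransformerMLP(num_expert=4, d_model=512, d_hidden=2048)
    assert no_norm.pre_lnorm is None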