Commit 61e4533f authored by Jiezhong Qiu

add final dropout before residual

parent 5f4441ee
@@ -49,10 +49,12 @@ class FMoETransformerMLP(FMoE):
             top_k=2,
             do_lnorm=False,
             pre_lnorm=False,
-            expert_dp_comm='none'
+            expert_dp_comm='none',
+            dropout=0.1
     ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
+        self.dropout = nn.Dropout(dropout)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
@@ -72,7 +74,9 @@ class FMoETransformerMLP(FMoE):
         inp = inp.reshape(-1, self.d_model)
         if self.pre_lnorm is not None and self.pre_lnorm:
             inp = self.layer_norm(inp)
-        output = super().forward(inp) + inp
+        output = super().forward(inp)
+        output = self.dropout(output)
+        output += inp
         if self.pre_lnorm is not None and not self.pre_lnorm:
             output = self.layer_norm(output)
         return output.reshape(original_shape)
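A minimal sketch of the changed residual path, for illustration only: `moe_mlp_forward` and `expert_fn` are hypothetical stand-ins (not the real FMoE API) for `FMoETransformerMLP.forward` and the expert mixture computed by `super().forward()`. The point of the commit is that dropout is now applied to the expert output before the residual add, matching the standard Transformer FFN sublayer.

```python
import torch
import torch.nn as nn

def moe_mlp_forward(inp, expert_fn, layer_norm, dropout, pre_lnorm=False):
    # Flatten (batch, seq, d_model) tokens into rows, as in the diff.
    original_shape = inp.shape
    inp = inp.reshape(-1, inp.shape[-1])
    if pre_lnorm:
        inp = layer_norm(inp)
    output = expert_fn(inp)          # expert mixture output
    output = dropout(output)         # new: dropout before the residual
    output = output + inp            # residual connection
    if not pre_lnorm:
        output = layer_norm(output)  # post-norm variant
    return output.reshape(original_shape)

# Usage with stand-in modules and hypothetical shapes.
d_model = 16
x = torch.randn(4, 8, d_model)
y = moe_mlp_forward(
    x,
    expert_fn=nn.Linear(d_model, d_model),  # placeholder for the MoE experts
    layer_norm=nn.LayerNorm(d_model),
    dropout=nn.Dropout(0.1),                # default introduced by this commit
    pre_lnorm=False,
)
print(y.shape)  # torch.Size([4, 8, 16])
```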