Commit fe2009b1 authored by Jiezhong Qiu

simplify TransformerMLP

Move layer norm, residual connection, and dropout out of FMoETransformerMLP; they now live in the caller (CustomizedMoEPositionwiseFF).
parent b56c8043
@@ -384,11 +384,29 @@ class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
             nn.Dropout(dropout)
         )
         super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
-                do_lnorm=True, pre_lnorm=pre_lnorm, activation=activation, dropout=dropout)
+                activation=activation)
+        self.pre_lnorm = pre_lnorm
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x):
-        x = super().forward(x)
-        return x
+    def forward(self, inp):
+        if self.pre_lnorm:
+            ##### layer normalization + positionwise feed-forward
+            core_out = super().forward(self.layer_norm(inp))
+            core_out = self.dropout(core_out)
+
+            ##### residual connection
+            output = core_out + inp
+        else:
+            ##### positionwise feed-forward
+            core_out = super().forward(inp)
+            core_out = self.dropout(core_out)
+
+            ##### residual connection + layer normalization
+            output = self.layer_norm(inp + core_out)
+
+        return output
 
 class DecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
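For context, the division of labor after this hunk: layer normalization, dropout, and the residual connection now live in CustomizedMoEPositionwiseFF, while the parent class only computes the expert feed-forward. Below is a minimal, self-contained sketch of that wrapping pattern; PlainMLP and PositionwiseFF are hypothetical stand-ins, not part of this repository, so the snippet runs on plain PyTorch without fastmoe.

# Standalone sketch (hypothetical names: PlainMLP, PositionwiseFF), assuming only
# plain PyTorch. It mirrors the new CustomizedMoEPositionwiseFF.forward above:
# the wrapper owns norm/dropout/residual; the inner MLP stands in for the MoE layer.
import torch
import torch.nn as nn

class PlainMLP(nn.Module):
    # stand-in for FMoETransformerMLP: a plain d_model -> d_inner -> d_model MLP
    def __init__(self, d_model, d_inner):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(),
                                 nn.Linear(d_inner, d_model))

    def forward(self, x):
        return self.net(x)

class PositionwiseFF(nn.Module):
    # norm/dropout/residual live here, exactly as in the new wrapper above
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super().__init__()
        self.core = PlainMLP(d_model, d_inner)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # layer norm -> feed-forward -> dropout, then residual
            core_out = self.dropout(self.core(self.layer_norm(inp)))
            return core_out + inp
        # feed-forward -> dropout, then residual + layer norm
        core_out = self.dropout(self.core(inp))
        return self.layer_norm(inp + core_out)

x = torch.randn(4, 16, 32)
print(PositionwiseFF(32, 64, 0.1, pre_lnorm=True)(x).shape)   # torch.Size([4, 16, 32])
print(PositionwiseFF(32, 64, 0.1, pre_lnorm=False)(x).shape)  # torch.Size([4, 16, 32])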
@@ -44,25 +44,15 @@ class FMoETransformerMLP(FMoE):
         d_hidden=4096,
         world_size=1,
         mp_group=None,
-        activation=torch.nn.functional.gelu,
+        activation=torch.nn.GELU(),
         gate=NaiveGate,
         top_k=2,
-        do_lnorm=False,
-        pre_lnorm=False,
         expert_dp_comm='none',
-        dropout=0.1
     ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
-        self.dropout = nn.Dropout(dropout)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
-        self.pre_lnorm = pre_lnorm
-        if do_lnorm:
-            self.layer_norm = nn.LayerNorm(d_model)
-            self.pre_lnorm = pre_lnorm
-        else:
-            self.pre_lnorm = None
         self.mark_parallel_comm(expert_dp_comm)
 
     def forward(self, inp: torch.Tensor):
@@ -72,11 +62,5 @@ class FMoETransformerMLP(FMoE):
         '''
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
-        if self.pre_lnorm is not None and self.pre_lnorm:
-            inp = self.layer_norm(inp)
         output = super().forward(inp)
-        output = self.dropout(output)
-        output += inp
-        if self.pre_lnorm is not None and not self.pre_lnorm:
-            output = self.layer_norm(output)
         return output.reshape(original_shape)
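After these hunks, FMoETransformerMLP.forward only flattens the input to (-1, d_model), runs the experts, and restores the original shape; layer norm, dropout, and the residual are gone from this class and must be supplied by the caller, as in the wrapper above. A rough usage sketch under those assumptions, assuming fastmoe and its CUDA extension are installed and a GPU is available; the constructor arguments follow the diff, the concrete values are arbitrary.

# Hedged usage sketch: build the pre-LN residual block yourself, since the
# simplified FMoETransformerMLP no longer applies norm/dropout/residual.
import torch
import torch.nn as nn
from fmoe.transformer import FMoETransformerMLP  # assumes fastmoe is installed

d_model = 512
device = 'cuda'  # fastmoe's expert dispatch kernels are CUDA-only

moe = FMoETransformerMLP(num_expert=4, d_model=d_model,
                         d_hidden=2048, top_k=2).to(device)
layer_norm = nn.LayerNorm(d_model).to(device)
dropout = nn.Dropout(0.1)

x = torch.randn(8, 16, d_model, device=device)  # any leading shape, last dim = d_model
y = x + dropout(moe(layer_norm(x)))             # pre-LN residual block around the MoE FF
assert y.shape == x.shape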