Commit fe2009b1 authored by Jiezhong Qiu

simplify TransformerMLP

move layernorm/residual/dropout out of FMoETransformerMLP and into the caller
parent b56c8043
@@ -384,11 +384,29 @@ class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
             nn.Dropout(dropout)
         )
         super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
-                do_lnorm=True, pre_lnorm=pre_lnorm, activation=activation, dropout=dropout)
-
-    def forward(self, x):
-        x = super().forward(x)
-        return x
+                activation=activation)
+        self.pre_lnorm = pre_lnorm
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, inp):
+        if self.pre_lnorm:
+            ##### layer normalization + positionwise feed-forward
+            core_out = super().forward(self.layer_norm(inp))
+            core_out = self.dropout(core_out)
+
+            ##### residual connection
+            output = core_out + inp
+        else:
+            ##### positionwise feed-forward
+            core_out = super().forward(inp)
+            core_out = self.dropout(core_out)
+
+            ##### residual connection + layer normalization
+            output = self.layer_norm(inp + core_out)
+
+        return output
 
 class DecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
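For reference, the ordering of operations in the new forward can be sketched with a plain two-layer MLP standing in for the MoE feed-forward. The class below is illustrative only and is not part of the commit; it reproduces the same pre-/post-LayerNorm, dropout, and residual pattern now handled in CustomizedMoEPositionwiseFF.

import torch
from torch import nn

class PositionwiseFFBlock(nn.Module):
    """Illustrative stand-in: same layer-norm/residual/dropout ordering as above,
    with a plain two-layer MLP in place of the MoE feed-forward."""

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(d_model, d_inner), nn.GELU(),
                                nn.Linear(d_inner, d_model))
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # layer norm first, then feed-forward, then the residual connection
            return inp + self.dropout(self.ff(self.layer_norm(inp)))
        # feed-forward first, then residual connection + layer norm
        return self.layer_norm(inp + self.dropout(self.ff(inp)))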
@@ -44,25 +44,15 @@ class FMoETransformerMLP(FMoE):
             d_hidden=4096,
             world_size=1,
             mp_group=None,
-            activation=torch.nn.functional.gelu,
+            activation=torch.nn.GELU(),
             gate=NaiveGate,
             top_k=2,
-            do_lnorm=False,
-            pre_lnorm=False,
             expert_dp_comm='none',
-            dropout=0.1
     ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
-        self.dropout = nn.Dropout(dropout)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
-        self.pre_lnorm = pre_lnorm
-        if do_lnorm:
-            self.layer_norm = nn.LayerNorm(d_model)
-            self.pre_lnorm = pre_lnorm
-        else:
-            self.pre_lnorm = None
         self.mark_parallel_comm(expert_dp_comm)
 
     def forward(self, inp: torch.Tensor):
@@ -72,11 +62,5 @@ class FMoETransformerMLP(FMoE):
         '''
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
-        if self.pre_lnorm is not None and self.pre_lnorm:
-            inp = self.layer_norm(inp)
         output = super().forward(inp)
-        output = self.dropout(output)
-        output += inp
-        if self.pre_lnorm is not None and not self.pre_lnorm:
-            output = self.layer_norm(output)
         return output.reshape(original_shape)
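With this simplification, callers of FMoETransformerMLP are responsible for layer norm, dropout, and the residual connection themselves. Below is a minimal usage sketch, assuming fastmoe is installed and the class is importable from fmoe.transformer (the file path is not shown in this diff); only the constructor keywords visible above are used, and fastmoe's expert kernels generally expect a CUDA device.

import torch
from torch import nn
from fmoe.transformer import FMoETransformerMLP  # assumed import path

d_model, d_inner = 512, 2048
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # MoE kernels usually need CUDA

# The MoE feed-forward alone: no layer norm, dropout, or residual inside anymore.
moe_ffn = FMoETransformerMLP(num_expert=4, d_model=d_model, d_hidden=d_inner,
                             top_k=2, activation=torch.nn.GELU()).to(device)
layer_norm = nn.LayerNorm(d_model).to(device)
dropout = nn.Dropout(0.1)

x = torch.randn(36, 4, d_model, device=device)   # (seq_len, batch, d_model)
# Post-LN residual block, mirroring the else-branch of CustomizedMoEPositionwiseFF.
y = layer_norm(x + dropout(moe_ffn(x)))
assert y.shape == x.shape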