Commit fb3e3c29 authored by Rick Ho

update the mem-transformer example

parent b3380ec2
@@ -825,10 +825,13 @@ class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
         super().__init__(num_expert=8, d_model=d_model, d_hidden=d_inner,
                          pre_lnorm=pre_lnorm, activation=activation)
         self.dropout = nn.Dropout(dropout)
+        self.bias = nn.Parameter(
+            torch.zeros(d_model, dtype=torch.float32)
+        )
 
     def forward(self, x):
-        x, bias = super().forward(x)
-        return x + bias
+        x = super().forward(x)
+        return x + self.bias
 
 
 class DecoderLayer(nn.Module):
...
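For reference, here is a minimal sketch of the class after this commit. The diff hunk only shows the super().__init__ call, the dropout, and the bias handling; the constructor signature, the activation construction, and the fmoe.transformer.FMoETransformerMLP import path are assumptions for illustration, not taken from this commit.

import torch
import torch.nn as nn
from fmoe.transformer import FMoETransformerMLP  # assumed import path

class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        # Hypothetical expert activation; the real one is built
        # outside the hunk shown in this commit.
        activation = nn.Sequential(nn.ReLU(), nn.Dropout(dropout))
        super().__init__(num_expert=8, d_model=d_model, d_hidden=d_inner,
                         pre_lnorm=pre_lnorm, activation=activation)
        self.dropout = nn.Dropout(dropout)
        # The MoE layer's forward() no longer returns a bias tensor,
        # so the example now owns a learnable per-feature output bias.
        self.bias = nn.Parameter(torch.zeros(d_model, dtype=torch.float32))

    def forward(self, x):
        x = super().forward(x)
        return x + self.bias

Because self.bias is registered as an nn.Parameter, it is trained along with the rest of the model, replacing the bias that the parent class previously returned from its forward pass.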