Commit 5dc62b41 authored by Jiezhong Qiu
Browse files

use nn.Sequential as activation

An activation passed as a plain function cannot be saved.
parent 61e4533f
......@@ -821,11 +821,13 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
from fmoe import FMoETransformerMLP
class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
# Constructor of CustomizedMoEPositionwiseFF, shown here as a diff hunk whose
# +/- markers and indentation were lost. Pre-commit and post-commit lines are
# interleaved, so this text is not a single runnable body. The removed/added
# labels below are inferred from the hunk header (11 -> 13 lines) and the
# commit message -- confirm against the actual commit.
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
# [appears removed] Activation as a local closure: ReLU, then self.dropout.
def activation(x):
return self.dropout(F.relu(x))
# [appears added] Activation as an nn.Sequential module instead of a closure
# (commit message: an activation passed as a function cannot be saved).
# Note the listed order here is Dropout first, then ReLU.
activation = nn.Sequential(
nn.Dropout(dropout),
nn.ReLU()
)
# [context] Start of the parent-constructor call; it forwards the expert
# count, model/hidden sizes and top-k. Two alternative endings follow.
super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
# [appears removed] Old ending of the call: no dropout keyword is forwarded.
do_lnorm=True, pre_lnorm=pre_lnorm, activation=activation)
# [appears removed] Dropout module stored on the instance; the closure above reads it.
self.dropout = nn.Dropout(dropout)
# [appears added] New ending of the call: also forwards dropout= to the parent.
do_lnorm=True, pre_lnorm=pre_lnorm, activation=activation, dropout=dropout)
# [appears added] The instance-level dropout assignment is left commented out.
#self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = super().forward(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment