Commit 9f8d8cd8 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

softmax rather than relu

parent 51361a3f
......@@ -272,7 +272,7 @@ class MultiHeadPositionwiseFF(nn.Module):
return output
class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True):
super(PositionwiseFF, self).__init__()
self.d_model = d_model
......@@ -281,7 +281,7 @@ class PositionwiseFF(nn.Module):
self.CoreNet_1 = nn.Sequential(
nn.Linear(d_model, d_inner),
nn.ReLU(inplace=True)
nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
)
self.CoreNet_2 = nn.Sequential(
nn.Dropout(dropout),
......@@ -719,7 +719,7 @@ class DecoderLayer(nn.Module):
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
# self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
......@@ -739,7 +739,7 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
......@@ -760,7 +760,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs)
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment