Commit 9f8d8cd8 authored by Jiezhong Qiu

softmax rather than relu

parent 51361a3f
@@ -272,7 +272,7 @@ class MultiHeadPositionwiseFF(nn.Module):
         return output
 
 class PositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True):
         super(PositionwiseFF, self).__init__()
 
         self.d_model = d_model
@@ -281,7 +281,7 @@ class PositionwiseFF(nn.Module):
         self.CoreNet_1 = nn.Sequential(
             nn.Linear(d_model, d_inner),
-            nn.ReLU(inplace=True)
+            nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
         )
         self.CoreNet_2 = nn.Sequential(
             nn.Dropout(dropout),
@@ -719,7 +719,7 @@ class DecoderLayer(nn.Module):
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -739,7 +739,7 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -760,7 +760,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
...
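For reference, below is a minimal sketch of how the modified PositionwiseFF might look after this commit. Only the __init__ signature and the CoreNet_1 activation are shown in the diff; the CoreNet_2 body and the forward pass here are assumptions based on the standard Transformer-XL feed-forward block, not part of this change.

import torch.nn as nn

class PositionwiseFF(nn.Module):
    """Position-wise feed-forward block whose first sub-network uses
    softmax (this commit) or ReLU (previous behaviour) as its activation."""
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True):
        super(PositionwiseFF, self).__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        # Changed in this commit: softmax over the inner dimension replaces
        # ReLU unless use_softmax=False is passed.
        self.CoreNet_1 = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
        )
        # Assumed: the second sub-network projects back to d_model with dropout,
        # as in the standard Transformer-XL implementation (not shown in the diff).
        self.CoreNet_2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # layer norm before the feed-forward, then residual connection
            core_out = self.CoreNet_2(self.CoreNet_1(self.layer_norm(inp)))
            output = core_out + inp
        else:
            # feed-forward first, residual connection, then layer norm
            core_out = self.CoreNet_2(self.CoreNet_1(inp))
            output = self.layer_norm(inp + core_out)
        return output

A decoder layer would then construct it exactly as in the hunks above, e.g. self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')), which also replaces the previous HierarchicalMoEPositionwiseFF in all three decoder layer variants.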