Commit 103c6395 authored by Jiezhong Qiu

update: comment out the torch_sparse import and the SparsePositionwiseFF class; switch the pos_ff module in the decoder layers to HierarchicalMoEPositionwiseFF

parent 19ee0ff2
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch_sparse
+# import torch_sparse
 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
@@ -152,51 +152,51 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         return output

-class SparsePositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-        super(SparsePositionwiseFF, self).__init__()
-        print("SparsePositionwiseFF")
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-        self.CoreNet_1 = nn.Sequential(
-            nn.Linear(d_model, d_inner),
-            nn.ReLU(inplace=True),
-            nn.Dropout(dropout)
-        )
-        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout_final = nn.Dropout(dropout)
-        self.reset_parameter()
-    def reset_parameter(self):
-        temp_Linear = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-        self.b2.data = temp_Linear.bias.data
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
-        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
-        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
-        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
-        core_out = self.dropout_final(core_out)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
+# class SparsePositionwiseFF(nn.Module):
+#     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+#         super(SparsePositionwiseFF, self).__init__()
+#         print("SparsePositionwiseFF")
+#         self.d_model = d_model
+#         self.d_inner = d_inner
+#         self.dropout = dropout
+#         self.CoreNet_1 = nn.Sequential(
+#             nn.Linear(d_model, d_inner),
+#             nn.ReLU(inplace=True),
+#             nn.Dropout(dropout)
+#         )
+#         self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
+#         self.b2 = nn.Parameter(torch.Tensor(d_model))
+#         self.layer_norm = nn.LayerNorm(d_model)
+#         self.pre_lnorm = pre_lnorm
+#         self.dropout_final = nn.Dropout(dropout)
+#         self.reset_parameter()
+#     def reset_parameter(self):
+#         temp_Linear = nn.Linear(self.d_inner, self.d_model)
+#         self.W2.data = temp_Linear.weight.data.transpose(0, 1)
+#         self.b2.data = temp_Linear.bias.data
+#     def forward(self, inp):
+#         residual = inp
+#         if self.pre_lnorm:
+#             inp = self.layer_norm(inp)
+#         relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
+#         sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
+#         core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
+#         core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
+#         core_out = self.dropout_final(core_out)
+#         output = core_out + residual
+#         if not self.pre_lnorm:
+#             output = self.layer_norm(output)
+#         return output

 class MultiHeadPositionwiseFF(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
@@ -711,7 +711,7 @@ class DecoderLayer(nn.Module):
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                            pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -731,7 +731,7 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                            pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -752,7 +752,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                            pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
...
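Note: the idea behind the now commented-out SparsePositionwiseFF is to exploit the exact zeros produced by the ReLU, converting the activations to a torch_sparse.SparseTensor so that the second projection becomes a sparse-dense matmul. The snippet below is a minimal sketch of that pattern in isolation, not part of this commit; it assumes the pytorch_sparse (torch_sparse) package is installed, uses illustrative shapes, and checks the result against the ordinary dense computation.

```python
# Minimal sketch of the sparse second projection used by SparsePositionwiseFF.
# Assumes torch_sparse (pytorch_sparse) is installed; all shapes are illustrative.
import torch
import torch.nn as nn
import torch_sparse

d_model, d_inner, n_tokens = 16, 64, 8

core_net_1 = nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(inplace=True))
W2 = torch.randn(d_inner, d_model)   # second projection weight, kept as a plain matrix
b2 = torch.zeros(d_model)            # second projection bias

x = torch.randn(n_tokens, d_model)
relu_out = core_net_1(x)             # (n_tokens, d_inner); ReLU zeroes out many entries

# Convert the activations to a SparseTensor and multiply by the dense W2.
sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
core_out = torch_sparse.matmul(sparse_relu_out, W2) + b2   # dense (n_tokens, d_model)

# The result matches the ordinary dense feed-forward computation.
assert torch.allclose(core_out, relu_out @ W2 + b2, atol=1e-6)
```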