Commit 103c6395 authored by Jiezhong Qiu

update

parent 19ee0ff2
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch_sparse
+# import torch_sparse
 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
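This hunk disables the dependency by commenting the import out. A common alternative that keeps the module importable whether or not the package is installed is an optional-import guard; a minimal sketch (the `TORCH_SPARSE_AVAILABLE` flag is a hypothetical name, not from this repository):

```python
# Optional-import guard: a sketch of an alternative to commenting the
# import out. TORCH_SPARSE_AVAILABLE is hypothetical, not original code.
try:
    import torch_sparse
    TORCH_SPARSE_AVAILABLE = True
except ImportError:
    torch_sparse = None
    TORCH_SPARSE_AVAILABLE = False
```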
@@ -152,51 +152,51 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         return output

-class SparsePositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-        super(SparsePositionwiseFF, self).__init__()
-        print("SparsePositionwiseFF")
+# class SparsePositionwiseFF(nn.Module):
+#     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+#         super(SparsePositionwiseFF, self).__init__()
+#         print("SparsePositionwiseFF")

-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
+#         self.d_model = d_model
+#         self.d_inner = d_inner
+#         self.dropout = dropout

-        self.CoreNet_1 = nn.Sequential(
-            nn.Linear(d_model, d_inner),
-            nn.ReLU(inplace=True),
-            nn.Dropout(dropout)
-        )
+#         self.CoreNet_1 = nn.Sequential(
+#             nn.Linear(d_model, d_inner),
+#             nn.ReLU(inplace=True),
+#             nn.Dropout(dropout)
+#         )

-        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
+#         self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
+#         self.b2 = nn.Parameter(torch.Tensor(d_model))

-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
+#         self.layer_norm = nn.LayerNorm(d_model)
+#         self.pre_lnorm = pre_lnorm

-        self.dropout_final = nn.Dropout(dropout)
-        self.reset_parameter()
+#         self.dropout_final = nn.Dropout(dropout)
+#         self.reset_parameter()

-    def reset_parameter(self):
-        temp_Linear = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-        self.b2.data = temp_Linear.bias.data
+#     def reset_parameter(self):
+#         temp_Linear = nn.Linear(self.d_inner, self.d_model)
+#         self.W2.data = temp_Linear.weight.data.transpose(0, 1)
+#         self.b2.data = temp_Linear.bias.data

-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
+#     def forward(self, inp):
+#         residual = inp
+#         if self.pre_lnorm:
+#             inp = self.layer_norm(inp)

-        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
-        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
-        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
-        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
-        core_out = self.dropout_final(core_out)
+#         relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
+#         sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
+#         core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
+#         core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
+#         core_out = self.dropout_final(core_out)

-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
+#         output = core_out + residual
+#         if not self.pre_lnorm:
+#             output = self.layer_norm(output)

-        return output
+#         return output

 class MultiHeadPositionwiseFF(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
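For context, the class commented out in the hunk above exploits the sparsity of ReLU activations: the second feed-forward projection is computed as a sparse-dense matmul via `torch_sparse` instead of a dense `nn.Linear`. A minimal self-contained sketch of that idea, assuming the rusty1s/pytorch_sparse package (the class name `SparseFFSketch` is hypothetical, not from this repository):

```python
# Sketch of the idea behind the (now commented-out) SparsePositionwiseFF,
# assuming the rusty1s/pytorch_sparse package; SparseFFSketch is a
# hypothetical name. The ReLU output is largely zero, so the second
# projection is done as a sparse @ dense matmul instead of a dense Linear.
import torch
import torch.nn as nn
import torch_sparse


class SparseFFSketch(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super().__init__()
        self.d_model, self.d_inner = d_model, d_inner
        # First projection + ReLU + dropout, mirroring CoreNet_1.
        self.core_net_1 = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )
        # Second projection kept as raw weights so it can be passed to
        # torch_sparse.matmul (which takes a SparseTensor and a dense matrix).
        self.W2 = nn.Parameter(torch.empty(d_inner, d_model))
        self.b2 = nn.Parameter(torch.empty(d_model))
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm
        self.dropout_final = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        # Borrow nn.Linear's default init, storing the transposed weight.
        tmp = nn.Linear(self.d_inner, self.d_model)
        with torch.no_grad():
            self.W2.copy_(tmp.weight.t())
            self.b2.copy_(tmp.bias)

    def forward(self, inp):
        # inp: [seq_len, batch, d_model], the Transformer-XL layout.
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        relu_out = self.core_net_1(inp).view(-1, self.d_inner)
        # Convert the dense ReLU output to a SparseTensor, then do a
        # sparse-dense matmul for the second projection.
        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
        output = self.dropout_final(core_out) + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output
```

Note that `SparseTensor.from_dense` still scans the full dense activation tensor, so any speedup depends on how sparse the ReLU output actually is at runtime.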
@@ -711,7 +711,7 @@ class DecoderLayer(nn.Module):
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                            pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -731,7 +731,7 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                            pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -752,7 +752,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                            pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):