Commit 1b7cbeb5 authored by Jiezhong Qiu

torch sparse for spmm

parent 47167bcc
@@ -7,6 +7,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_sparse
sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
@@ -82,6 +83,50 @@ class MoEPositionwiseFF(nn.Module):
        return output
        # return output, relu_out.detach()
class SparsePositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(SparsePositionwiseFF, self).__init__()
        print("SparsePositionwiseFF")

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet_1 = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )
        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
        self.b2 = nn.Parameter(torch.Tensor(d_model))

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm
        self.dropout_final = nn.Dropout(dropout)

        self.reset_parameter()

    def reset_parameter(self):
        # Initialize W2/b2 with the same scheme nn.Linear uses.
        temp_Linear = nn.Linear(self.d_inner, self.d_model)
        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
        self.b2.data = temp_Linear.bias.data

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        # Flatten to [qlen * bsz, d_inner]; the post-ReLU activation is sparse,
        # so the second projection is done as a sparse-dense matmul.
        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
        core_out = self.dropout_final(core_out)

        # Restore the [qlen, bsz, d_model] shape before the residual add.
        output = core_out.view_as(residual) + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)

        return output
class MultiHeadPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
@@ -96,7 +141,7 @@ class MultiHeadPositionwiseFF(nn.Module):
        self.d_inner = d_inner
        self.dropout = dropout

        #self.q_net = nn.Linear(d_model, d_model)
        self.q_net = nn.Linear(d_model, d_model)

        self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
        self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
@@ -129,8 +174,8 @@ class MultiHeadPositionwiseFF(nn.Module):
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        # head_q = self.q_net(inp)
        # head_q = inp.view(inp.size(0), inp.size(1), self.n_head, self.d_head)  # [.. x n_head x d_head]
        head_q = self.q_net(inp)
        head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head)  # [.. x n_head x d_head]

        attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias  # [.. x n_head x d_inner]
        attn_score = F.relu(attn_score)
@@ -148,6 +193,47 @@ class MultiHeadPositionwiseFF(nn.Module):
        return output


class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet_1 = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True)
        )
        self.CoreNet_2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            relu_out = self.CoreNet_1(self.layer_norm(inp))
            core_out = self.CoreNet_2(relu_out)

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            relu_out = self.CoreNet_1(inp)
            core_out = self.CoreNet_2(relu_out)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output
        # return output, relu_out.detach()
class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
......
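For reference, a minimal standalone sketch (not part of the commit) of the sparse-dense multiply pattern that SparsePositionwiseFF.forward relies on, assuming torch_sparse is installed; the tensor sizes below are made up for illustration and the variable names mirror the module's W2/b2.

import torch
import torch_sparse

# Hypothetical sizes: 8 token rows, d_inner=16, d_model=4.
relu_out = torch.relu(torch.randn(8, 16))   # post-ReLU activations are mostly zero
W2 = torch.randn(16, 4)
b2 = torch.randn(4)

# Dense reference result.
dense_out = relu_out @ W2 + b2

# Sparse path, as in SparsePositionwiseFF: wrap the activation in a
# torch_sparse.SparseTensor and multiply it against the dense weight matrix.
sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
sparse_out = torch_sparse.matmul(sparse_relu_out, W2) + b2

print(torch.allclose(dense_out, sparse_out, atol=1e-5))  # expected: True

The payoff depends on how sparse the ReLU output actually is; converting a mostly dense activation to a SparseTensor on every forward pass can cost more than the spmm saves.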