Commit 823f9c2e authored by Jiezhong Qiu

Recover sparse FFN: re-enable the SparsePositionwiseFF module and the torch_sparse import.

parent 0f3e63eb
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-# import torch_sparse
+import torch_sparse
 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
@@ -114,6 +114,7 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         self.dropout_middle = nn.Dropout(dropout * ratio)
         self.dropout_final = nn.Dropout(dropout)
+        self.scale = 1 / (d_model ** 0.5)
         self.reset_parameter()

     def reset_parameter(self):
@@ -131,6 +132,8 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         block = self.block_net(inp)
         block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False)  # [.. x top_k]
+        block_val.mul_(self.scale)
         gate = F.softmax(block_val, dim=-1)
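For context on the gating change in the hunk above, here is a minimal standalone sketch of scaled top-k softmax gating; the tensor sizes, the `n_block` count, and the free-standing `block_net` Linear are illustrative assumptions, not the repository's actual configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes (assumptions for this sketch only).
d_model, n_block, top_block = 8, 4, 2
scale = 1 / (d_model ** 0.5)

block_net = nn.Linear(d_model, n_block)            # expert-block router
inp = torch.randn(5, d_model)                      # [tokens x d_model]

block = block_net(inp)                              # [tokens x n_block] routing logits
block_val, block_idx = torch.topk(block, k=top_block, dim=-1, largest=True, sorted=False)
block_val.mul_(scale)                               # temper the selected logits before softmax
gate = F.softmax(block_val, dim=-1)                 # gate weights over the top-k blocks
print(gate.shape, block_idx.shape)                  # torch.Size([5, 2]) torch.Size([5, 2])
```

The softmax is taken only over the k selected logits, so each token's gate weights sum to one across its chosen expert blocks.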
@@ -154,51 +157,51 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         return output

-# class SparsePositionwiseFF(nn.Module):
-#     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-#         super(SparsePositionwiseFF, self).__init__()
-#         print("SparsePositionwiseFF")
+class SparsePositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+        super(SparsePositionwiseFF, self).__init__()
+        print("SparsePositionwiseFF")

-#         self.d_model = d_model
-#         self.d_inner = d_inner
-#         self.dropout = dropout
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout

-#         self.CoreNet_1 = nn.Sequential(
-#             nn.Linear(d_model, d_inner),
-#             nn.ReLU(inplace=True),
-#             nn.Dropout(dropout)
-#         )
+        self.CoreNet_1 = nn.Sequential(
+            nn.Linear(d_model, d_inner),
+            nn.ReLU(inplace=True),
+            nn.Dropout(dropout)
+        )

-#         self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-#         self.b2 = nn.Parameter(torch.Tensor(d_model))
+        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
+        self.b2 = nn.Parameter(torch.Tensor(d_model))

-#         self.layer_norm = nn.LayerNorm(d_model)
-#         self.pre_lnorm = pre_lnorm
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.pre_lnorm = pre_lnorm

-#         self.dropout_final = nn.Dropout(dropout)
-#         self.reset_parameter()
+        self.dropout_final = nn.Dropout(dropout)
+        self.reset_parameter()

-#     def reset_parameter(self):
-#         temp_Linear = nn.Linear(self.d_inner, self.d_model)
-#         self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-#         self.b2.data = temp_Linear.bias.data
+    def reset_parameter(self):
+        temp_Linear = nn.Linear(self.d_inner, self.d_model)
+        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
+        self.b2.data = temp_Linear.bias.data

-#     def forward(self, inp):
-#         residual = inp
-#         if self.pre_lnorm:
-#             inp = self.layer_norm(inp)
+    def forward(self, inp):
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)

-#         relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
-#         sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
-#         core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
-#         core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
-#         core_out = self.dropout_final(core_out)
+        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
+        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
+        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
+        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
+        core_out = self.dropout_final(core_out)

-#         output = core_out + residual
-#         if not self.pre_lnorm:
-#             output = self.layer_norm(output)
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)

-#         return output
+        return output

 class MultiHeadPositionwiseFF(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
......
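To see what the restored forward path does in isolation, here is a minimal sketch of the ReLU-to-sparse-matmul step, assuming the `torch_sparse` package (pytorch_sparse) is installed; the dimensions, the standalone `core_net_1`, `W2`, and `b2` are illustrative stand-ins for the module's members, not the repository's code.

```python
import torch
import torch.nn as nn
import torch_sparse  # assumes the pytorch_sparse package is installed

# Illustrative sizes (assumptions for this sketch only).
d_model, d_inner, seqlen, bsz = 16, 64, 3, 2

core_net_1 = nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(inplace=True))
W2 = torch.randn(d_inner, d_model)   # stand-in for the module's W2 parameter
b2 = torch.zeros(d_model)            # stand-in for the module's b2 parameter

inp = torch.randn(seqlen, bsz, d_model)

# ReLU zeroes out many activations, so the flattened output is converted to a
# SparseTensor and multiplied against W2 with a sparse-dense matmul.
relu_out = core_net_1(inp).view(-1, d_inner)                  # [seqlen*bsz x d_inner]
sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
core_out = torch_sparse.matmul(sparse_relu_out, W2) + b2      # dense [seqlen*bsz x d_model]
core_out = core_out.view(seqlen, bsz, d_model)
print(core_out.shape)  # torch.Size([3, 2, 16])
```

Whether this pays off depends on how sparse the ReLU activations actually are: `SparseTensor.from_dense` has its own conversion cost, so on mostly dense activations a plain `relu_out @ W2` would typically be faster.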