Commit 823f9c2e authored by Jiezhong Qiu

recover sparse ffn

parent 0f3e63eb
@@ -7,7 +7,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-# import torch_sparse
+import torch_sparse
 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
@@ -114,6 +114,7 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         self.dropout_middle = nn.Dropout(dropout * ratio)
         self.dropout_final = nn.Dropout(dropout)
+        self.scale = 1 / (d_model ** 0.5)
         self.reset_parameter()

     def reset_parameter(self):
@@ -131,6 +132,8 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         block = self.block_net(inp)
         block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False)  # [.. x top_k]
+        block_val.mul_(self.scale)
+
         gate = F.softmax(block_val, dim=-1)
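The hunk above applies a 1/sqrt(d_model) temperature to the top-k gating logits before the softmax, much like the 1/sqrt(d_k) factor in scaled dot-product attention, presumably to keep the gate distribution from saturating. Below is a minimal self-contained sketch of that gating path; the sizes, the block_net layer, and the random input are illustrative stand-ins, not values from this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, n_block, top_block = 16, 8, 2            # illustrative sizes only
block_net = nn.Linear(d_model, n_block)           # one gating logit per expert block
inp = torch.randn(4, d_model)                     # toy batch of token representations

block = block_net(inp)                                            # [4 x n_block]
block_val, block_idx = torch.topk(block, k=top_block, dim=-1,
                                  largest=True, sorted=False)     # [4 x top_block]
block_val = block_val * (1 / d_model ** 0.5)      # the scaling added by this commit
gate = F.softmax(block_val, dim=-1)               # gate weights over the selected blocks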
@@ -154,51 +157,51 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         return output

-# class SparsePositionwiseFF(nn.Module):
-#     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-#         super(SparsePositionwiseFF, self).__init__()
-#         print("SparsePositionwiseFF")
-#         self.d_model = d_model
-#         self.d_inner = d_inner
-#         self.dropout = dropout
-#         self.CoreNet_1 = nn.Sequential(
-#             nn.Linear(d_model, d_inner),
-#             nn.ReLU(inplace=True),
-#             nn.Dropout(dropout)
-#         )
-#         self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-#         self.b2 = nn.Parameter(torch.Tensor(d_model))
-#         self.layer_norm = nn.LayerNorm(d_model)
-#         self.pre_lnorm = pre_lnorm
-#         self.dropout_final = nn.Dropout(dropout)
-#         self.reset_parameter()
-#
-#     def reset_parameter(self):
-#         temp_Linear = nn.Linear(self.d_inner, self.d_model)
-#         self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-#         self.b2.data = temp_Linear.bias.data
-#
-#     def forward(self, inp):
-#         residual = inp
-#         if self.pre_lnorm:
-#             inp = self.layer_norm(inp)
-#         relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
-#         sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
-#         core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
-#         core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
-#         core_out = self.dropout_final(core_out)
-#         output = core_out + residual
-#         if not self.pre_lnorm:
-#             output = self.layer_norm(output)
-#         return output
+class SparsePositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+        super(SparsePositionwiseFF, self).__init__()
+        print("SparsePositionwiseFF")
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+        self.CoreNet_1 = nn.Sequential(
+            nn.Linear(d_model, d_inner),
+            nn.ReLU(inplace=True),
+            nn.Dropout(dropout)
+        )
+        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
+        self.b2 = nn.Parameter(torch.Tensor(d_model))
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.pre_lnorm = pre_lnorm
+        self.dropout_final = nn.Dropout(dropout)
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        temp_Linear = nn.Linear(self.d_inner, self.d_model)
+        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
+        self.b2.data = temp_Linear.bias.data
+
+    def forward(self, inp):
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
+        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
+        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
+        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
+        core_out = self.dropout_final(core_out)
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+        return output

 class MultiHeadPositionwiseFF(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
...
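To complement the restored SparsePositionwiseFF, here is a minimal sketch of the sparse-dense product its forward pass relies on, assuming the torch_sparse package (rusty1s/pytorch_sparse) is installed. The sizes and the random sparsity mask below are illustrative; only the SparseTensor.from_dense / torch_sparse.matmul pattern mirrors the class above.

import torch
import torch_sparse

tokens, d_inner, d_model = 6, 32, 8               # illustrative sizes only
W2 = torch.randn(d_inner, d_model)
b2 = torch.randn(d_model)

# After ReLU (and dropout) many inner activations are exactly zero, which is what
# makes converting the activation matrix to a sparse layout worthwhile.
relu_out = torch.relu(torch.randn(tokens, d_inner))
relu_out = relu_out * (torch.rand(tokens, d_inner) > 0.7).float()

sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
core_out = torch_sparse.matmul(sparse_relu_out, W2) + b2    # sparse @ dense -> dense [tokens x d_model]

# Same result as the plain dense projection; any payoff is in speed/memory and
# depends on how sparse the ReLU output actually is.
assert torch.allclose(core_out, relu_out @ W2 + b2, atol=1e-5)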