Commit 19ee0ff2 authored by Jiezhong Qiu's avatar Jiezhong Qiu

hmoe

parent 3bdbb28d
@@ -83,6 +83,75 @@ class MoEPositionwiseFF(nn.Module):
        return output
        # return output, relu_out.detach()
class HierarchicalMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=512, top_block=128):
        super(HierarchicalMoEPositionwiseFF, self).__init__()
        print("HierarchicalMoEPositionwiseFF")
        assert d_inner % n_block == 0

        self.top_block = top_block
        self.n_block = n_block
        d_block = d_inner // n_block
        self.d_block = d_block
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        # gating network: one logit per block, so top-k blocks can be selected per position
        self.block_net = nn.Linear(d_model, n_block)
        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
        self.b2 = nn.Parameter(torch.Tensor(d_model))

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

        # scale the inner dropout by the fraction of blocks that are active
        ratio = top_block / n_block
        self.dropout_middle = nn.Dropout(dropout * ratio)
        self.dropout_final = nn.Dropout(dropout)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize the block-partitioned weights from dense nn.Linear layers
        # of the same total size, then reshape into per-block slices
        temp = nn.Linear(self.d_model, self.d_inner)
        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
        temp = nn.Linear(self.d_inner, self.d_model)
        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(self.n_block, self.d_block, self.d_model)
        self.b2.data = temp.bias.data

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        block = self.block_net(inp)  # [.. x n_block]
        # select the top-k blocks per position; the gate values themselves are not
        # used to weight the outputs, only to choose which blocks run
        block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False)  # [.. x top_k]
        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block
        relu_out = F.relu(x)
        relu_out = self.dropout_middle(relu_out)
        W2_block = self.W2[block_idx]  # [.. x top_k x d_block x d_model]
        core_out = torch.einsum('ibnh,ibnhd->ibd', (relu_out, W2_block)) + self.b2  # [.. x d_model]
        core_out = self.dropout_final(core_out)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output
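
A minimal smoke test of the new block, for reference. This is only a sketch: the hyperparameter values and tensor shapes below are illustrative, not taken from this commit's configs; the input follows the [seq_len x batch x d_model] layout that the einsum strings assume.

# Hypothetical usage of HierarchicalMoEPositionwiseFF; all values are illustrative.
import torch
ff = HierarchicalMoEPositionwiseFF(d_model=512, d_inner=2048, dropout=0.1,
                                   n_block=512, top_block=128)   # d_block = 2048 // 512 = 4
inp = torch.randn(16, 4, 512)    # [seq_len x batch x d_model]
out = ff(inp)
assert out.shape == inp.shape    # residual feed-forward preserves the model dimension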
class SparsePositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(SparsePositionwiseFF, self).__init__()
@@ -194,48 +263,6 @@ class MultiHeadPositionwiseFF(nn.Module):
        return output
class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet_1 = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True)
        )
        self.CoreNet_2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            relu_out = self.CoreNet_1(self.layer_norm(inp))
            core_out = self.CoreNet_2(relu_out)

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            relu_out = self.CoreNet_1(inp)
            core_out = self.CoreNet_2(relu_out)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output
        # return output, relu_out.detach()
class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()
@@ -684,7 +711,7 @@ class DecoderLayer(nn.Module):
        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
                                           pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -704,7 +731,7 @@ class RelLearnableDecoderLayer(nn.Module):
        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                  **kwargs)
-        self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
                                           pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -725,7 +752,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                         d_head, dropout, **kwargs)
-        self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = SparsePositionwiseFF(d_model, d_inner, dropout,
                                           pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
...
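
The one-line swap above is repeated identically in all three decoder layer variants: pos_ff changes from MultiHeadPositionwiseFF to SparsePositionwiseFF while the constructor arguments stay untouched. A hedged sketch of what a caller sees afterwards, assuming the Transformer-XL-style layer signatures used in this file (hyperparameter values are invented for the example):

# Illustrative only; values are made up, and pre_lnorm is forwarded via **kwargs.
layer = RelPartialLearnableDecoderLayer(n_head=8, d_model=512, d_head=64,
                                        d_inner=2048, dropout=0.1,
                                        pre_lnorm=True)
assert isinstance(layer.pos_ff, SparsePositionwiseFF)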