Commit b1dd8572 authored by Jiezhong Qiu

multihead moe

parent 510ac924
@@ -99,6 +99,91 @@ def my_topk(x, k, inplace=True):
    top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
    return top_val, top_idx
class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
        super(MultiHeadHierarchicalMoEPositionwiseFF, self).__init__()
        print("MultiHeadHierarchicalMoEPositionwiseFF")
        assert d_inner % n_block == 0
        assert top_block in [1, 2]
        self.top_block = top_block
        self.n_block = n_block
        d_block = d_inner // n_block  # width of each expert block
        self.d_block = d_block
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout
        # one routing head per selected block: each head scores all n_block experts
        self.block_net_W = nn.Parameter(torch.Tensor(d_model, top_block, n_block))
        self.block_net_b = nn.Parameter(torch.Tensor(top_block, n_block))
        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
        self.b2 = nn.Parameter(torch.Tensor(d_model))
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm
        # scale the inner dropout by the fraction of blocks that are active
        ratio = top_block / n_block
        self.dropout_middle = nn.Dropout(dropout * ratio)
        self.dropout_final = nn.Dropout(dropout)
        # self.scale = 1 / (d_model ** 0.5)
        self.reset_parameter()
    def reset_parameter(self):
        # borrow the default nn.Linear initialization for the block weights
        temp = nn.Linear(self.d_model, self.d_inner)
        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
        temp = nn.Linear(self.d_inner, self.d_model)
        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(self.n_block, self.d_block, self.d_model)
        self.b2.data = temp.bias.data
        for i in range(self.top_block):
            temp = nn.Linear(self.d_model, self.n_block)
            # index into .data so the copy writes through to the parameter
            # (assigning .data on a sliced view would leave the parameter untouched)
            self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
            self.block_net_b.data[i] = temp.bias.data
    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        # each routing head scores every expert block
        block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) + self.block_net_b  # [.. x top_block x n_block]
        # block_val, block_idx = my_topk(block, k=1)
        # each head picks its single best block
        block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False)  # [.. x top_block x 1]
        block_val = block_val.squeeze(-1)  # [.. x top_block]
        block_idx = block_idx.squeeze(-1)  # [.. x top_block]
        # softmax over the selected blocks gives the gate weights
        gate = F.softmax(block_val, dim=-1)
        W1_block = self.W1[block_idx]  # [.. x top_block x d_block x d_model]
        b1_block = self.b1[block_idx]  # [.. x top_block x d_block]
        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_block x d_block]
        # x = x + block_val.unsqueeze(-1)  # somehow like a residual
        x = x * gate.unsqueeze(-1)
        relu_out = F.relu(x)
        relu_out = self.dropout_middle(relu_out)
        W2_block = self.W2[block_idx]  # [.. x top_block x d_block x d_model]
        core_out = torch.einsum('ibnh,ibnhd->ibd', (relu_out, W2_block)) + self.b2  # [.. x d_model]
        core_out = self.dropout_final(core_out)
        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output
class HierarchicalMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
        super(HierarchicalMoEPositionwiseFF, self).__init__()
...
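
For reference, a minimal smoke test of the new layer might look like the sketch below. The hyperparameter values and tensor sizes are illustrative assumptions, not taken from the commit; they are chosen only so that d_inner % n_block == 0 as the constructor asserts, and the input follows the [seq_len x batch x d_model] layout that the "ibd" einsum subscripts above assume.

    import torch

    # hypothetical instantiation of the class defined in this commit
    ff = MultiHeadHierarchicalMoEPositionwiseFF(
        d_model=512, d_inner=2048, dropout=0.1, n_block=16, top_block=2)

    inp = torch.randn(32, 4, 512)  # [seq_len x batch x d_model]
    out = ff(inp)
    print(out.shape)  # torch.Size([32, 4, 512]) -- the layer is shape-preserving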