Commit 51361a3f authored by Jiezhong Qiu

residual in hmoe

parent 823f9c2e
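
Summary of the change: the scaling of the router logits (block_val.mul_(self.scale)) is commented out, the top-k logits block_val are added onto the expert pre-activations before gating as a residual-style term, and the constructor defaults shrink from n_block=512, top_block=128 to n_block=64, top_block=8.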
@@ -84,7 +84,7 @@ class MoEPositionwiseFF(nn.Module):
         # return output, relu_out.detach()
 class HierarchicalMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=512, top_block=128):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=64, top_block=8):
         super(HierarchicalMoEPositionwiseFF, self).__init__()
         print("HierarchicalMoEPositionwiseFF")
@@ -133,14 +133,17 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
         block = self.block_net(inp)
         block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
-        block_val.mul_(self.scale)
+        # block_val.mul_(self.scale)
         gate = F.softmax(block_val, dim=-1)
         W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model]
         b1_block = self.b1[block_idx] # [.. x top_k x d_block]
-        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block
+        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block # [.. x top_k x d_block]
+        x = x + block_val.unsqueeze(-1) # acts somewhat like a residual connection
         x = x * gate.unsqueeze(-1)
         relu_out = F.relu(x)
...
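
For context, a minimal standalone sketch of the gating path this commit changes. This is not the repository's actual module; the shapes ([seq x batch x d_model]), the nn.Linear standing in for self.block_net, and the random weight tensors are assumptions made so the snippet runs on its own.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

seq_len, batch, d_model = 4, 2, 16      # assumed toy shapes: [seq x batch x d_model]
n_block, top_block, d_block = 64, 8, 32 # defaults from this commit; d_block is assumed

inp = torch.randn(seq_len, batch, d_model)
router = torch.nn.Linear(d_model, n_block)          # stands in for self.block_net
W1 = torch.randn(n_block, d_block, d_model) * 0.02  # per-block first-layer weights
b1 = torch.zeros(n_block, d_block)

block = router(inp)                                       # [i x b x n_block]
block_val, block_idx = torch.topk(block, k=top_block, dim=-1,
                                  largest=True, sorted=False)  # [i x b x top_k]
# logit scaling is disabled in this commit, so block_val feeds the gate directly
gate = F.softmax(block_val, dim=-1)                       # [i x b x top_k]

W1_block = W1[block_idx]                                  # [i x b x top_k x d_block x d_model]
b1_block = b1[block_idx]                                  # [i x b x top_k x d_block]

x = torch.einsum('ibd,ibnhd->ibnh', inp, W1_block) + b1_block  # [i x b x top_k x d_block]
x = x + block_val.unsqueeze(-1)   # the residual-style term added by this commit
x = x * gate.unsqueeze(-1)
relu_out = F.relu(x)
print(relu_out.shape)  # torch.Size([4, 2, 8, 32])

One consequence of the change: because block_val now enters both the softmax gate and the pre-activations, an expert block with a higher router score receives both a larger gate weight and a positive shift before the ReLU.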