Commit 39996fef authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

by default using 1 block

parent 6860c1bf
...@@ -100,7 +100,7 @@ def my_topk(x, k, inplace=True): ...@@ -100,7 +100,7 @@ def my_topk(x, k, inplace=True):
return top_val, top_idx return top_val, top_idx
class HierarchicalMoEPositionwiseFF(nn.Module): class HierarchicalMoEPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=64, top_block=2): def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=32, top_block=1):
super(HierarchicalMoEPositionwiseFF, self).__init__() super(HierarchicalMoEPositionwiseFF, self).__init__()
print("HierarchicalMoEPositionwiseFF") print("HierarchicalMoEPositionwiseFF")
...@@ -115,7 +115,7 @@ class HierarchicalMoEPositionwiseFF(nn.Module): ...@@ -115,7 +115,7 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
self.d_inner = d_inner self.d_inner = d_inner
self.dropout = dropout self.dropout = dropout
self.block_net = nn.Linear(d_model, n_block) self.block_net = nn.Linear(d_model, n_block, bias=False)
self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model)) self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b1 = nn.Parameter(torch.Tensor(n_block, d_block)) self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
...@@ -152,11 +152,8 @@ class HierarchicalMoEPositionwiseFF(nn.Module): ...@@ -152,11 +152,8 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
block_val, block_idx = my_topk(block, k=self.top_block) block_val, block_idx = my_topk(block, k=self.top_block)
# block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k] # block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
# block_val.mul_(self.scale)
gate = F.softmax(block_val, dim=-1) gate = F.softmax(block_val, dim=-1)
W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model] W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model]
b1_block = self.b1[block_idx] # [.. x top_k x d_block] b1_block = self.b1[block_idx] # [.. x top_k x d_block]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment