"vscode:/vscode.git/clone" did not exist on "c7e6d09068a88e752b43eed0f2c4e56ace6b7005"
Commit 39996fef authored by Jiezhong Qiu
Browse files

by default using 1 block

parent 6860c1bf
......@@ -100,7 +100,7 @@ def my_topk(x, k, inplace=True):
return top_val, top_idx
class HierarchicalMoEPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=64, top_block=2):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=32, top_block=1):
super(HierarchicalMoEPositionwiseFF, self).__init__()
print("HierarchicalMoEPositionwiseFF")
......@@ -115,7 +115,7 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
self.d_inner = d_inner
self.dropout = dropout
self.block_net = nn.Linear(d_model, n_block)
self.block_net = nn.Linear(d_model, n_block, bias=False)
self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
......@@ -152,11 +152,8 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
block_val, block_idx = my_topk(block, k=self.top_block)
# block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
# block_val.mul_(self.scale)
gate = F.softmax(block_val, dim=-1)
W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model]
b1_block = self.b1[block_idx] # [.. x top_k x d_block]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment