Commit 0f3e63eb authored by Jiezhong Qiu
Browse files

add gate

parent a43caff7
......@@ -131,12 +131,14 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
block = self.block_net(inp)
block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
gate = F.softmax(block_val, dim=-1)
W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model]
b1_block = self.b1[block_idx] # [.. x top_k x d_block]
x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block
x = x * gate.unsqueeze(-1)
relu_out = F.relu(x)
relu_out = self.dropout_middle(relu_out)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment