Commit 2fef9391 authored by Jiezhong Qiu

update

parent 7eb40a4a
@@ -144,8 +144,8 @@ class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
         self.b2.data = temp.bias.data
         for i in range(self.top_block):
             temp = nn.Linear(self.d_model, self.n_block)
-            self.block_net_W[:, i].data = temp.weight.data.transpose(0, 1).contiguous()
-            self.block_net_b[:, i].data = temp.bias.data
+            self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
+            self.block_net_b.data[i] = temp.bias.data
 
     def forward(self, inp):
         residual = inp
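The two rewritten assignments fix how the per-expert routing parameters are initialized: indexing the parameter first and then assigning to .data only re-seats the .data attribute of a temporary sliced view, so the parameter's own storage was never written; indexing into param.data copies through the parameter's storage. The bias index also changes from a column slice [:, i] to row i, presumably matching one bias vector per top-level block. A minimal sketch of the difference, with toy shapes chosen here only for illustration:

    import torch
    import torch.nn as nn

    # Toy sizes; the real module shapes block_net_W as (d_model, top_block, n_block)
    # to match the einsum "ibd,dan->iban" used in forward().
    d_model, top_block, n_block = 4, 2, 3
    W = nn.Parameter(torch.zeros(d_model, top_block, n_block))
    src = torch.ones(d_model, n_block)

    W[:, 0].data = src           # old pattern: re-seats .data of a temporary slice; W stays zero
    print(W.data.abs().sum())    # tensor(0.)

    W.data[:, 0] = src           # new pattern: copies into W's own storage
    print(W.data.abs().sum())    # tensor(12.)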
@@ -154,8 +154,8 @@ class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
         block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) + self.block_net_b # [.. x top_block x n_block ]
-        # block_val, block_idx = my_topk(block, k=1)
-        block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False) # [.. x top_k x 1]
+        block_val, block_idx = my_topk(block, k=1, inplace=True)
+        # block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False) # [.. x top_k x 1]
         block_val = block_val.squeeze(-1)
         block_idx = block_idx.squeeze(-1)
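my_topk is a helper defined elsewhere in this repository, so its exact semantics and the inplace=True flag are not visible in this diff. What it replaces here is a k=1 selection over the expert-block dimension, which is the decision the previous torch.topk call made and, for k=1, the same as a plain max/argmax. A small check of that equivalence, with a tensor shape invented to follow the [.. x top_block x n_block] comment:

    import torch

    # Invented shape following the comment: [seq_len, batch, top_block, n_block].
    block = torch.randn(5, 2, 4, 8)

    val_k, idx_k = torch.topk(block, k=1, dim=-1, largest=True, sorted=False)
    val_k, idx_k = val_k.squeeze(-1), idx_k.squeeze(-1)

    # For k=1, the same values and routing indices come from a max over the last dim.
    val_m, idx_m = block.max(dim=-1)
    assert torch.equal(val_k, val_m) and torch.equal(idx_k, idx_m)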
@@ -820,7 +820,7 @@ class DecoderLayer(nn.Module):
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                                     pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -840,7 +840,7 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                                     pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -861,7 +861,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                                     pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
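The three decoder-layer hunks above make the same substitution in each variant: DecoderLayer, RelLearnableDecoderLayer, and RelPartialLearnableDecoderLayer now build MultiHeadHierarchicalMoEPositionwiseFF instead of HierarchicalMoEPositionwiseFF as their position-wise feed-forward sublayer. The constructor arguments and the call site in forward are unchanged, so the new module acts as a drop-in replacement, roughly as in this self-contained sketch (the two sub-modules below are trivial stand-ins, not the repository's classes):

    import torch
    import torch.nn as nn

    class DecoderLayerSketch(nn.Module):
        """Illustrates the call pattern shared by all three decoder layers."""
        def __init__(self, d_model):
            super().__init__()
            self.dec_attn = nn.Identity()              # stand-in for the attention sub-module
            self.pos_ff = nn.Linear(d_model, d_model)  # stand-in for MultiHeadHierarchicalMoEPositionwiseFF

        def forward(self, dec_inp, dec_attn_mask=None, mems=None):
            # The real dec_attn consumes dec_attn_mask and mems; the stand-in ignores them.
            output = self.dec_attn(dec_inp)
            output = self.pos_ff(output)               # only this sub-module's class changes in the commit
            return output

    x = torch.randn(10, 2, 16)                         # [seq_len, batch, d_model]
    print(DecoderLayerSketch(16)(x).shape)             # torch.Size([10, 2, 16])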