Commit 2fef9391 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

update

parent 7eb40a4a
...@@ -144,8 +144,8 @@ class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module): ...@@ -144,8 +144,8 @@ class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
self.b2.data = temp.bias.data self.b2.data = temp.bias.data
for i in range(self.top_block): for i in range(self.top_block):
temp = nn.Linear(self.d_model, self.n_block) temp = nn.Linear(self.d_model, self.n_block)
self.block_net_W[:, i].data = temp.weight.data.transpose(0, 1).contiguous() self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
self.block_net_b[:, i].data = temp.bias.data self.block_net_b.data[i] = temp.bias.data
def forward(self, inp): def forward(self, inp):
residual = inp residual = inp
...@@ -154,8 +154,8 @@ class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module): ...@@ -154,8 +154,8 @@ class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) + self.block_net_b # [.. x top_block x n_block ] block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) + self.block_net_b # [.. x top_block x n_block ]
# block_val, block_idx = my_topk(block, k=1) block_val, block_idx = my_topk(block, k=1, inplace=True)
block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False) # [.. x top_k x 1] # block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False) # [.. x top_k x 1]
block_val = block_val.squeeze(-1) block_val = block_val.squeeze(-1)
block_idx = block_idx.squeeze(-1) block_idx = block_idx.squeeze(-1)
...@@ -820,7 +820,7 @@ class DecoderLayer(nn.Module): ...@@ -820,7 +820,7 @@ class DecoderLayer(nn.Module):
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
# self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout, self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None): def forward(self, dec_inp, dec_attn_mask=None, mems=None):
...@@ -840,7 +840,7 @@ class RelLearnableDecoderLayer(nn.Module): ...@@ -840,7 +840,7 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs) **kwargs)
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout, self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
...@@ -861,7 +861,7 @@ class RelPartialLearnableDecoderLayer(nn.Module): ...@@ -861,7 +861,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs) d_head, dropout, **kwargs)
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout, self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment