"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9ee3dd38626624e063a738b220d81ab6df271fdc"
Commit 1feaaf0c authored by Jiezhong Qiu

customized topk

* When k=1, the custom routine reduces to torch.max, and it is not surprising
  that torch.max is faster than torch.topk.
* However, when k=2, the two-pass approach is even slower than torch.topk.
parent fc279894
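Since the commit message makes a timing claim, here is a minimal CPU benchmark sketch (not part of the commit; the tensor shape and iteration count are assumptions) that reproduces the torch.max vs. torch.topk comparison behind it:

```python
# Timing sketch (assumptions: gate-logit shape [tokens, n_block] and the
# iteration count are made up; on GPU you would also need to call
# torch.cuda.synchronize() around the timed region for valid numbers).
import time

import torch

x = torch.randn(4096, 64)  # hypothetical gate logits: [tokens, n_block]

def bench(fn, iters=1000):
    fn()  # warm-up call, excluded from the measurement
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

t_max  = bench(lambda: torch.max(x, dim=-1))
t_top1 = bench(lambda: torch.topk(x, k=1, dim=-1, largest=True, sorted=False))
t_top2 = bench(lambda: torch.topk(x, k=2, dim=-1, largest=True, sorted=False))
print(f"torch.max       : {t_max * 1e6:8.1f} us")
print(f"torch.topk, k=1 : {t_top1 * 1e6:8.1f} us")
print(f"torch.topk, k=2 : {t_top2 * 1e6:8.1f} us")
```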
@@ -83,12 +83,29 @@ class MoEPositionwiseFF(nn.Module):
         return output
         # return output, relu_out.detach()
 
+def my_topk(x, k, inplace=True):
+    y = x if inplace else x.clone()
+    top1_val, top1_idx = torch.max(y, dim=-1)
+    top1_val = top1_val.unsqueeze(-1)
+    top1_idx = top1_idx.unsqueeze(-1)
+    if k == 1:
+        return top1_val, top1_idx
+    y.scatter_(-1, top1_idx, value=float('-inf'))
+    top2_val, top2_idx = torch.max(y, dim=-1)
+    top2_val = top2_val.unsqueeze(-1)
+    top2_idx = top2_idx.unsqueeze(-1)
+    top_val = torch.cat((top1_val, top2_val), dim=-1)
+    top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
+    return top_val, top_idx
+
+
 class HierarchicalMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=64, top_block=8):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=64, top_block=2):
         super(HierarchicalMoEPositionwiseFF, self).__init__()
         print("HierarchicalMoEPositionwiseFF")
         assert d_inner % n_block == 0
+        assert top_block in [1, 2]
         self.top_block = top_block
         self.n_block = n_block
@@ -131,7 +148,9 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
             inp = self.layer_norm(inp)
 
         block = self.block_net(inp)
-        block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
+        block_val, block_idx = my_topk(block, k=self.top_block)
+        # block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
         # block_val.mul_(self.scale)
@@ -719,7 +738,7 @@ class DecoderLayer(nn.Module):
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -739,7 +758,7 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -760,7 +779,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
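For reference, a small sanity check (not part of the commit) that the customized routine agrees with torch.topk for k=2. One caveat visible in the diff: the forward pass calls my_topk with the default inplace=True, so the gate logits in block get -inf written at the top-1 position as a side effect of routing.

```python
import torch

def my_topk(x, k, inplace=True):
    # Same routine as added in the diff above: take the max, mask it out
    # with -inf, then take the max again to recover the second-best entry.
    y = x if inplace else x.clone()
    top1_val, top1_idx = torch.max(y, dim=-1)
    top1_val = top1_val.unsqueeze(-1)
    top1_idx = top1_idx.unsqueeze(-1)
    if k == 1:
        return top1_val, top1_idx
    y.scatter_(-1, top1_idx, value=float('-inf'))
    top2_val, top2_idx = torch.max(y, dim=-1)
    top2_val = top2_val.unsqueeze(-1)
    top2_idx = top2_idx.unsqueeze(-1)
    top_val = torch.cat((top1_val, top2_val), dim=-1)
    top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
    return top_val, top_idx

x = torch.randn(8, 64)
ref_val, ref_idx = torch.topk(x, k=2, dim=-1, largest=True, sorted=True)

# my_topk returns (top-1, top-2), i.e. descending order, matching sorted=True.
val, idx = my_topk(x, k=2, inplace=False)  # inplace=False leaves x untouched
assert torch.equal(val, ref_val) and torch.equal(idx, ref_idx)

# Caveat: with the default inplace=True the input itself is mutated.
y = x.clone()
my_topk(y, k=2)
assert (y == float('-inf')).any()
```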