"src/array/cuda/uvm/array_index_select_uvm.hip" did not exist on "548c85fff6b0a5b96f6064c86397e15477283f95"
Commit 1feaaf0c authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

customized topk

* When k=1, it reduces to torch.max; unsurprisingly, torch.max is
faster than torch.topk.
* However, when k=2, it is even slower than torch.topk.
parent fc279894
......@@ -83,12 +83,29 @@ class MoEPositionwiseFF(nn.Module):
return output
# return output, relu_out.detach()
def my_topk(x, k, inplace=True):
    """Return the ``k`` largest entries of ``x`` along its last dimension.

    Selection is done with repeated ``torch.max`` calls: after each pick the
    chosen position is overwritten with ``-inf`` so the next call finds the
    next-largest entry. Results are therefore ordered from largest to
    smallest along the last dimension.

    Args:
        x: Input tensor. For ``k > 1`` it must be able to store ``-inf``.
        k: Number of entries to select; ``1 <= k <= x.size(-1)``.
        inplace: If true (the default), the ``-inf`` masking is written
            directly into ``x``, so every selected entry except the last one
            is overwritten in the caller's tensor. If false, a clone of ``x``
            is masked instead and ``x`` is left untouched.

    Returns:
        A ``(values, indices)`` tuple; each has the leading dimensions of
        ``x`` and a last dimension of size ``k``.

    Raises:
        ValueError: If ``k`` is outside ``[1, x.size(-1)]``.
    """
    # A 0-dim tensor is treated as holding a single candidate.
    n_candidates = x.size(-1) if x.dim() > 0 else 1
    if not 1 <= k <= n_candidates:
        raise ValueError(
            "k must be in [1, {}], got {}".format(n_candidates, k))

    y = x if inplace else x.clone()
    vals, idxs = [], []
    for step in range(k):
        val, idx = torch.max(y, dim=-1)
        val = val.unsqueeze(-1)
        idx = idx.unsqueeze(-1)
        vals.append(val)
        idxs.append(idx)
        # Mask the winner so the next pass finds the runner-up. The last
        # pick needs no masking, which keeps k == 1 free of any write.
        if step < k - 1:
            y.scatter_(-1, idx, value=float('-inf'))

    if k == 1:
        return vals[0], idxs[0]
    return torch.cat(vals, dim=-1), torch.cat(idxs, dim=-1)
class HierarchicalMoEPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=64, top_block=8):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=64, top_block=2):
super(HierarchicalMoEPositionwiseFF, self).__init__()
print("HierarchicalMoEPositionwiseFF")
assert d_inner % n_block == 0
assert top_block in [1, 2]
self.top_block = top_block
self.n_block = n_block
......@@ -131,7 +148,9 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
inp = self.layer_norm(inp)
block = self.block_net(inp)
block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
block_val, block_idx = my_topk(block, k=self.top_block)
# block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
# block_val.mul_(self.scale)
......@@ -719,7 +738,7 @@ class DecoderLayer(nn.Module):
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
# self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
......@@ -739,7 +758,7 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
......@@ -760,7 +779,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
self.pos_ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment