Commit fa023f32 authored by Jiezhong Qiu

allow user to specify the number of experts and top-k

parent 3bdfae96
@@ -583,7 +583,8 @@ class MultiHeadAttn(nn.Module):

 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
+                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
+                 moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()

         self.n_head = n_head
@@ -819,10 +820,10 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):

 from fmoe import FMoETransformerMLP
 class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
         def activation(x):
             return self.dropout(F.relu(x))
-        super().__init__(num_expert=64, d_model=d_model, d_hidden=d_inner, topk=2,
+        super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
                 pre_lnorm=pre_lnorm, activation=activation)
         self.dropout = nn.Dropout(dropout)
         self.bias = nn.Parameter(
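For context on the two knobs being threaded through here: num_expert is how many independent position-wise FFN experts the layer holds, and top_k is how many of them the gate routes each token to. The sketch below is a minimal, self-contained top-k MoE feed-forward in plain PyTorch, written only to illustrate those semantics; it is not fmoe's FMoETransformerMLP, which dispatches tokens to experts with optimized kernels rather than the naive Python loop used here.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveTopKMoEFF(nn.Module):
    """Illustrative top-k MoE position-wise FFN (a sketch, not fmoe's implementation)."""
    def __init__(self, d_model, d_hidden, num_expert=64, top_k=2, dropout=0.1):
        super().__init__()
        assert num_expert >= top_k
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_expert)      # scores every expert for every token
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.ReLU(),
                          nn.Dropout(dropout), nn.Linear(d_hidden, d_model))
            for _ in range(num_expert)])

    def forward(self, x):                               # x: [..., d_model]
        scores = self.gate(x)                           # [..., num_expert]
        topv, topi = scores.topk(self.top_k, dim=-1)    # hard gate: keep only top_k experts
        weight = F.softmax(topv, dim=-1)                # normalize over the selected experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            idx = topi[..., k]                          # chosen expert id per token for slot k
            for e, expert in enumerate(self.experts):   # slow loop, clarity over speed
                mask = (idx == e)
                if mask.any():
                    out[mask] += weight[..., k][mask].unsqueeze(-1) * expert(x[mask])
        return out

With the defaults (num_expert=64, top_k=2) each token activates only 2 of the 64 expert FFNs, so per-token compute stays close to a dense FFN of width d_inner while the layer stores many times more parameters.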
@@ -841,7 +842,9 @@ class DecoderLayer(nn.Module):

         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                         pre_lnorm=kwargs.get('pre_lnorm'))
+                                         pre_lnorm=kwargs.get('pre_lnorm'),
+                                         moe_num_expert=kwargs.get('moe_num_expert'),
+                                         moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
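A small note on this pattern, which recurs in the two relative-attention decoder layers below: dict.get without a default returns None when the key is absent, so these layers assume the caller always forwards both moe_num_expert and moe_top_k (as MemTransformerLM now does). Purely as a suggestion, a defensive variant could fall back to the same defaults used elsewhere in this commit:

# hypothetical variant, not part of this commit: fall back to the defaults (64, 2)
self.pos_ff = CustomizedMoEPositionwiseFF(
    d_model, d_inner, dropout,
    pre_lnorm=kwargs.get('pre_lnorm'),
    moe_num_expert=kwargs.get('moe_num_expert', 64),
    moe_top_k=kwargs.get('moe_top_k', 2))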
@@ -861,7 +864,9 @@ class RelLearnableDecoderLayer(nn.Module):

         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                         pre_lnorm=kwargs.get('pre_lnorm'))
+                                         pre_lnorm=kwargs.get('pre_lnorm'),
+                                         moe_num_expert=kwargs.get('moe_num_expert'),
+                                         moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -882,7 +887,9 @@ class RelPartialLearnableDecoderLayer(nn.Module):

         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                             d_head, dropout, **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                         pre_lnorm=kwargs.get('pre_lnorm'))
+                                         pre_lnorm=kwargs.get('pre_lnorm'),
+                                         moe_num_expert=kwargs.get('moe_num_expert'),
+                                         moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
@@ -967,7 +974,7 @@ class MemTransformerLM(nn.Module):

                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1):
+                 sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
         self.n_token = n_token
@@ -998,7 +1005,8 @@ class MemTransformerLM(nn.Module):

                     RelPartialLearnableDecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type == 1: # learnable embeddings
             for i in range(n_layer):
@@ -1006,14 +1014,16 @@ class MemTransformerLM(nn.Module):

                     RelLearnableDecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]: # absolute embeddings
             for i in range(n_layer):
                 self.layers.append(
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )

         self.sample_softmax = sample_softmax
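The two new constructor arguments reach every decoder layer, so they set a model-wide trade-off: stored expert weights grow with moe_num_expert, while the FFN work done per token grows with moe_top_k. A rough back-of-the-envelope, counting weights only and ignoring the gate, biases and layer norm (the sizes are made-up examples, not values from this commit):

# rough per-layer estimate for the MoE position-wise FFN (weights only)
d_model, d_inner = 512, 2048          # example sizes, not taken from this commit
moe_num_expert, moe_top_k = 64, 2     # the new defaults

params_per_expert = 2 * d_model * d_inner            # W1: d_model x d_inner, W2: d_inner x d_model
total_ffn_params  = moe_num_expert * params_per_expert
active_per_token  = moe_top_k * params_per_expert    # weights actually touched per token

print(f"{total_ffn_params/1e6:.1f}M stored vs {active_per_token/1e6:.1f}M used per token")
# -> 134.2M stored vs 4.2M used per token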
@@ -167,8 +167,13 @@ parser.add_argument('--static-loss-scale', type=float, default=1,

 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe-num-expert', type=int, default=64,
+                    help='number of experts in MoE')
+parser.add_argument('--moe-top_k', type=int, default=2,
+                    help='top_k experts in hard gate of moe')

 args = parser.parse_args()
 args.tied = not args.not_tied
+assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"

 if args.d_embed < 0:
     args.d_embed = args.d_model
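One detail worth spelling out: argparse derives attribute names by replacing '-' with '_' after the leading dashes, so both --moe-num-expert and the (differently spelled) --moe-top_k land on args.moe_num_expert and args.moe_top_k, which is what the new assertion and the model construction below use. A hypothetical standalone check of that behaviour, not part of the commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top_k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

# e.g. a smaller expert pool with the same top-2 hard gate
args = parser.parse_args(['--moe-num-expert', '16', '--moe-top_k', '2'])
assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
print(args.moe_num_expert, args.moe_top_k)   # 16 2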
@@ -305,7 +310,8 @@ else:

         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
-        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
+        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
+        moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing

 args.n_all_param = sum([p.nelement() for p in model.parameters()])
@@ -571,7 +577,7 @@ def train():

                 # for i in range(len(avg_nnzs)):
                 #     avg_nnzs[i].reset()
                 #     act_hist[i] /= act_hist[i].sum()
-                #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
+                #     prob, index = torch.top_k(act_hist[i], min(1024, act_hist[i].size(-1)))
                 #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
                 #         i+1,
                 #         prob[:64].sum().item(),