Commit fa023f32 authored by Jiezhong Qiu

allow user to specify the number of experts and top_k

parent 3bdfae96
@@ -583,7 +583,8 @@ class MultiHeadAttn(nn.Module):
 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
+                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
+                 moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()
         self.n_head = n_head
@@ -819,10 +820,10 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
 from fmoe import FMoETransformerMLP
 class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
         def activation(x):
             return self.dropout(F.relu(x))
-        super().__init__(num_expert=64, d_model=d_model, d_hidden=d_inner, topk=2,
+        super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
             pre_lnorm=pre_lnorm, activation=activation)
         self.dropout = nn.Dropout(dropout)
         self.bias = nn.Parameter(
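For orientation, the sketch below is a minimal, self-contained toy of what the two new hyperparameters mean for a position-wise MoE feed-forward layer: num_expert is how many expert FFNs exist, and top_k is how many of them each position is routed to by the gate. It is plain PyTorch, not the fmoe implementation; all class names, shapes, and values are illustrative only.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyMoEPositionwiseFF(nn.Module):
    """Toy top-k mixture-of-experts feed-forward layer (illustration only)."""
    def __init__(self, d_model, d_inner, dropout, num_expert=64, top_k=2):
        super().__init__()
        assert num_expert >= top_k, "cannot route to more experts than exist"
        self.num_expert, self.top_k = num_expert, top_k
        # One independent position-wise FFN per expert.
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(),
                          nn.Linear(d_inner, d_model))
            for _ in range(num_expert)])
        self.gate = nn.Linear(d_model, num_expert)   # one score per expert
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):                            # x: [..., d_model]
        scores = self.gate(x)                        # [..., num_expert]
        top_val, top_idx = torch.topk(scores, self.top_k, dim=-1)
        weight = F.softmax(top_val, dim=-1)          # renormalise over the k chosen experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):                  # send each position to its k experts
            idx = top_idx[..., k]
            for e in range(self.num_expert):
                mask = idx == e
                if mask.any():
                    out[mask] += weight[..., k][mask].unsqueeze(-1) * self.experts[e](x[mask])
        return self.dropout(out)

# e.g. ToyMoEPositionwiseFF(d_model=512, d_inner=2048, dropout=0.1, num_expert=4, top_k=2)

The real FMoETransformerMLP dispatches positions to experts with fused kernels rather than the explicit Python loops above; only the meaning of num_expert and top_k carries over.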
@@ -841,7 +842,9 @@ class DecoderLayer(nn.Module):
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'))
+                                                  pre_lnorm=kwargs.get('pre_lnorm'),
+                                                  moe_num_expert=kwargs.get('moe_num_expert'),
+                                                  moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -861,7 +864,9 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'))
+                                                  pre_lnorm=kwargs.get('pre_lnorm'),
+                                                  moe_num_expert=kwargs.get('moe_num_expert'),
+                                                  moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -882,7 +887,9 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'))
+                                                  pre_lnorm=kwargs.get('pre_lnorm'),
+                                                  moe_num_expert=kwargs.get('moe_num_expert'),
+                                                  moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
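One Python detail that matters when constructing any of these decoder layers: dict.get with a single argument returns None for a missing key, so the layers only see the 64/2 defaults when the caller actually supplies moe_num_expert and moe_top_k in kwargs (as MemTransformerLM does below). A two-line illustration with hypothetical values:

kwargs = {'pre_lnorm': True}              # caller omitted the MoE settings
print(kwargs.get('moe_num_expert'))       # None -- no fallback value
print(kwargs.get('moe_num_expert', 64))   # 64 -- fallback supplied explicitly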
@@ -967,7 +974,7 @@ class MemTransformerLM(nn.Module):
                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1):
+                 sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
         self.n_token = n_token
@@ -998,7 +1005,8 @@ class MemTransformerLM(nn.Module):
                     RelPartialLearnableDecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type == 1: # learnable embeddings
             for i in range(n_layer):
@@ -1006,14 +1014,16 @@ class MemTransformerLM(nn.Module):
                     RelLearnableDecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]: # absolute embeddings
             for i in range(n_layer):
                 self.layers.append(
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )

         self.sample_softmax = sample_softmax
...
@@ -167,8 +167,13 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe-num-expert', type=int, default=64,
+                    help='number of experts in MoE')
+parser.add_argument('--moe-top_k', type=int, default=2,
+                    help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied
+assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"

 if args.d_embed < 0:
     args.d_embed = args.d_model
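The new flags and the consistency check can be exercised in isolation. The snippet below is a minimal reproduction of just this part of the parser (every other option omitted); argparse maps '--moe-num-expert' to args.moe_num_expert and '--moe-top_k' to args.moe_top_k.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top_k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

# Parse an explicit argument list so the demo does not depend on sys.argv.
args = parser.parse_args(['--moe-num-expert', '4', '--moe-top_k', '2'])
assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
print(args.moe_num_expert, args.moe_top_k)   # 4 2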
@@ -305,7 +310,8 @@ else:
         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
-        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
+        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
+        moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing

 args.n_all_param = sum([p.nelement() for p in model.parameters()])
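Assuming the patched model file (imported here as mem_transformer, the usual Transformer-XL module name) and a working fmoe installation, a small smoke test could thread the new keyword arguments end to end. All sizes below are arbitrary illustrative values, not a recommended configuration:

import torch
from mem_transformer import MemTransformerLM   # assumed file name of the patched model

device = torch.device('cuda')                   # fmoe's fused experts need a CUDA build
model = MemTransformerLM(
    n_token=1000, n_layer=2, n_head=2, d_model=64, d_head=32, d_inner=128,
    dropout=0.1, dropatt=0.0, tgt_len=32, ext_len=0, mem_len=32,
    moe_num_expert=4, moe_top_k=2).to(device)

data = torch.randint(0, 1000, (32, 4), device=device)    # [tgt_len, batch] token ids
target = torch.randint(0, 1000, (32, 4), device=device)
loss, *mems = model(data, target)               # Transformer-XL returns [loss] + new memories
print(loss.float().mean().item())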
@@ -571,7 +577,7 @@ def train():
                 # for i in range(len(avg_nnzs)):
                 #     avg_nnzs[i].reset()
                 #     act_hist[i] /= act_hist[i].sum()
-                #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
+                #     prob, index = torch.top_k(act_hist[i], min(1024, act_hist[i].size(-1)))
                 #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
                 #         i+1,
                 #         prob[:64].sum().item(),
...