Commit fa023f32 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

allow user to indicate number of experts and topk

parent 3bdfae96
......@@ -583,7 +583,8 @@ class MultiHeadAttn(nn.Module):
class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
moe_num_expert=64, moe_top_k=2):
super(RelMultiHeadAttn, self).__init__()
self.n_head = n_head
......@@ -819,10 +820,10 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
from fmoe import FMoETransformerMLP
class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
def activation(x):
return self.dropout(F.relu(x))
super().__init__(num_expert=64, d_model=d_model, d_hidden=d_inner, topk=2,
super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
pre_lnorm=pre_lnorm, activation=activation)
self.dropout = nn.Dropout(dropout)
self.bias = nn.Parameter(
......@@ -841,7 +842,9 @@ class DecoderLayer(nn.Module):
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
# self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
......@@ -861,7 +864,9 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
......@@ -882,7 +887,9 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
......@@ -967,7 +974,7 @@ class MemTransformerLM(nn.Module):
tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1):
sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
super(MemTransformerLM, self).__init__()
self.n_token = n_token
......@@ -998,7 +1005,8 @@ class MemTransformerLM(nn.Module):
RelPartialLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm)
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
elif attn_type == 1: # learnable embeddings
for i in range(n_layer):
......@@ -1006,14 +1014,16 @@ class MemTransformerLM(nn.Module):
RelLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm)
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
elif attn_type in [2, 3]: # absolute embeddings
for i in range(n_layer):
self.layers.append(
DecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
dropatt=dropatt, pre_lnorm=pre_lnorm)
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
self.sample_softmax = sample_softmax
......
......@@ -167,8 +167,13 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
parser.add_argument('--dynamic-loss-scale', action='store_true',
help='Use dynamic loss scaling. If supplied, this argument'
' supersedes --static-loss-scale.')
parser.add_argument('--moe-num-expert', type=int, default=64,
help='number of experts in MoE')
parser.add_argument('--moe-top_k', type=int, default=2,
help='top_k experts in hard gate of moe')
args = parser.parse_args()
args.tied = not args.not_tied
assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
if args.d_embed < 0:
args.d_embed = args.d_model
......@@ -305,7 +310,8 @@ else:
tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
same_length=args.same_length, attn_type=args.attn_type,
clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
model.apply(weights_init)
model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
......@@ -571,7 +577,7 @@ def train():
# for i in range(len(avg_nnzs)):
# avg_nnzs[i].reset()
# act_hist[i] /= act_hist[i].sum()
# prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
# prob, index = torch.top_k(act_hist[i], min(1024, act_hist[i].size(-1)))
# log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
# i+1,
# prob[:64].sum().item(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment