Commit 0a942e3f authored by Jiezhong Qiu

allow user to indicate whether to use MoE

parent bdb64914
@@ -143,7 +143,7 @@ class MultiHeadAttn(nn.Module):
 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 moe_num_expert=64, moe_top_k=2):
+                 moe=False, moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()
 
         self.n_head = n_head
@@ -395,10 +395,14 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'),
-            moe_num_expert=kwargs.get('moe_num_expert'),
-            moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'),
+                moe_num_expert=kwargs.get('moe_num_expert'),
+                moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -415,10 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'),
-            moe_num_expert=kwargs.get('moe_num_expert'),
-            moe_top_k=kwargs.get('moe_top_k'))
+
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'),
+                moe_num_expert=kwargs.get('moe_num_expert'),
+                moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -436,10 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                             d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'),
-            moe_num_expert=kwargs.get('moe_num_expert'),
-            moe_top_k=kwargs.get('moe_top_k'))
+
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'),
+                moe_num_expert=kwargs.get('moe_num_expert'),
+                moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
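
A note on the guard used in all three decoder layers above: kwargs.get('moe') returns None when the keyword is absent, and None is not False, so a layer constructed without an explicit moe=... argument falls through to the else branch and still builds the MoE feed-forward. The commit keeps this safe by always threading moe=moe down from MemTransformerLM, but it matters if the layer classes are instantiated directly. A minimal illustration (the helper below is for demonstration only and is not part of the diff):

# Demonstrates the branch condition used in the decoder layers:
# 'moe' must be passed explicitly for the dense branch to be taken.
def pick_ffn(**kwargs):
    if kwargs.get('moe') is False:
        return 'PositionwiseFF'
    else:
        return 'CustomizedMoEPositionwiseFF'

assert pick_ffn(moe=False) == 'PositionwiseFF'
assert pick_ffn(moe=True) == 'CustomizedMoEPositionwiseFF'
# If 'moe' is omitted, kwargs.get('moe') is None, and None is not False,
# so the MoE branch is selected.
assert pick_ffn() == 'CustomizedMoEPositionwiseFF'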
@@ -521,7 +535,7 @@ class MemTransformerLM(nn.Module):
                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
+                 sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
 
         self.n_token = n_token
...@@ -553,7 +567,7 @@ class MemTransformerLM(nn.Module): ...@@ -553,7 +567,7 @@ class MemTransformerLM(nn.Module):
n_head, d_model, d_head, d_inner, dropout, n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm, dropatt=dropatt, pre_lnorm=pre_lnorm,
moe_num_expert=moe_num_expert, moe_top_k=moe_top_k) moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
) )
elif attn_type == 1: # learnable embeddings elif attn_type == 1: # learnable embeddings
for i in range(n_layer): for i in range(n_layer):
@@ -562,7 +576,7 @@ class MemTransformerLM(nn.Module):
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]: # absolute embeddings
             for i in range(n_layer):
@@ -570,7 +584,7 @@ class MemTransformerLM(nn.Module):
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
 
         self.sample_softmax = sample_softmax
...
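
Taken together, the model-side change reduces to one boolean switch on MemTransformerLM that every decoder layer forwards to its feed-forward block. A minimal usage sketch, assuming mem_transformer.py from this repository is on the import path; the hyperparameter values are illustrative, and the moe=True variant additionally needs the fastmoe package backing CustomizedMoEPositionwiseFF:

# Sketch only: toggling MoE when constructing the model directly.
# Values below are illustrative, not taken from the commit.
from mem_transformer import MemTransformerLM

common = dict(
    n_token=10000, n_layer=4, n_head=8, d_model=512, d_head=64,
    d_inner=2048, dropout=0.1, dropatt=0.0,
    tgt_len=128, ext_len=0, mem_len=128,
)

# Dense baseline: every layer builds the standard PositionwiseFF.
dense_model = MemTransformerLM(moe=False, **common)

# MoE variant: every layer builds CustomizedMoEPositionwiseFF
# with 64 experts and top-2 gating.
moe_model = MemTransformerLM(moe=True, moe_num_expert=64, moe_top_k=2, **common)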
@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
         --batch_size 22 \
         --multi_gpu \
         --gpu0_bsz 4 \
+        --moe --moe-num-expert 64 --moe-top-k 2 \
         ${@:2}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'
...
@@ -141,9 +141,11 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
 parser.add_argument('--moe-num-expert', type=int, default=64,
                     help='number of experts in MoE')
-parser.add_argument('--moe-top_k', type=int, default=2,
+parser.add_argument('--moe-top-k', type=int, default=2,
                     help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied
@@ -285,7 +287,7 @@ else:
             ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
             same_length=args.same_length, attn_type=args.attn_type,
             clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
-            moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
+            moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
     args.n_all_param = sum([p.nelement() for p in model.parameters()])
...
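
On the training-script side, note that the flag spelling changes from --moe-top_k to --moe-top-k. Since argparse replaces dashes with underscores when deriving attribute names, both spellings surface as args.moe_top_k, so the rename only makes the CLI consistent with --moe-num-expert. A short, self-contained sketch of how the parsed flags reach the model constructor (abbreviated; the commented call at the end mirrors the diff):

# Sketch of the flag-to-constructor flow; abbreviated from train.py.
# argparse turns '-' into '_' when naming attributes, so --moe-top-k
# (and the old --moe-top_k spelling) both become args.moe_top_k.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe', action='store_true',
                    help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

args = parser.parse_args(['--moe', '--moe-num-expert', '64', '--moe-top-k', '2'])
assert (args.moe, args.moe_num_expert, args.moe_top_k) == (True, 64, 2)

# These values are then forwarded to the model constructor, e.g.:
# model = MemTransformerLM(..., moe=args.moe,
#                          moe_num_expert=args.moe_num_expert,
#                          moe_top_k=args.moe_top_k)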