Commit 0a942e3f authored by Jiezhong Qiu

allow user to indicate whether to use MoE

parent bdb64914
@@ -143,7 +143,7 @@ class MultiHeadAttn(nn.Module):
class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
moe_num_expert=64, moe_top_k=2):
moe=False, moe_num_expert=64, moe_top_k=2):
super(RelMultiHeadAttn, self).__init__()
self.n_head = n_head
@@ -395,10 +395,14 @@ class DecoderLayer(nn.Module):
super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
if kwargs.get('moe') is False:
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
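A note on the dispatch used in these decoder layers (the snippet below is an added illustration, not part of the commit): `kwargs.get('moe')` returns `None` when the key is absent, and `None is False` is false, so a caller that omits `moe` entirely would fall through to the MoE branch. Within this commit that cannot happen, because `MemTransformerLM` always forwards `moe=moe` (see the hunks further down).

# Illustration only (not part of the diff): behavior of the kwargs.get('moe') check.
kwargs = {'moe': False}                   # what MemTransformerLM passes when --moe is off
print(kwargs.get('moe') is False)         # True  -> the dense PositionwiseFF is built

kwargs = {}                               # a hypothetical caller that never passes 'moe'
print(kwargs.get('moe') is False)         # False -> the MoE branch would be taken
print(kwargs.get('moe', False) is False)  # True  -> a stricter default, if that is preferred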
@@ -415,10 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
if kwargs.get('moe') is False:
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -436,10 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
if kwargs.get('moe') is False:
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
@@ -521,7 +535,7 @@ class MemTransformerLM(nn.Module):
tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
super(MemTransformerLM, self).__init__()
self.n_token = n_token
@@ -553,7 +567,7 @@ class MemTransformerLM(nn.Module):
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
elif attn_type == 1: # learnable embeddings
for i in range(n_layer):
@@ -562,7 +576,7 @@ class MemTransformerLM(nn.Module):
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
elif attn_type in [2, 3]: # absolute embeddings
for i in range(n_layer):
@@ -570,7 +584,7 @@ class MemTransformerLM(nn.Module):
DecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
self.sample_softmax = sample_softmax
......
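`CustomizedMoEPositionwiseFF` itself is not shown in this diff. For context, the sketch below shows how such a module can wrap FastMoE's `FMoETransformerMLP`; the import path, the base-class keyword names, and the residual/LayerNorm arrangement are assumptions modeled on the dense `PositionwiseFF`, not code taken from this commit.

# Sketch only (not part of the diff): an MoE position-wise FFN built on FastMoE.
import torch.nn as nn
from fmoe.transformer import FMoETransformerMLP  # assumed FastMoE import path

class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
    # Drop-in replacement for PositionwiseFF: the two dense linear layers are
    # replaced by an expert-parallel MLP with a hard top-k gate.
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
                 moe_num_expert=64, moe_top_k=2):
        activation = nn.Sequential(nn.ReLU(), nn.Dropout(dropout))
        super().__init__(num_expert=moe_num_expert, d_model=d_model,
                         d_hidden=d_inner, top_k=moe_top_k, activation=activation)
        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp):
        if self.pre_lnorm:
            core_out = self.dropout(super().forward(self.layer_norm(inp)))
            return core_out + inp                    # residual around the MoE FFN
        core_out = self.dropout(super().forward(inp))
        return self.layer_norm(inp + core_out)       # post-norm, as in PositionwiseFF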
@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
--batch_size 22 \
--multi_gpu \
--gpu0_bsz 4 \
--moe --moe-num-expert 64 --moe-top-k 2 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
......
@@ -141,9 +141,11 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
parser.add_argument('--dynamic-loss-scale', action='store_true',
help='Use dynamic loss scaling. If supplied, this argument'
' supersedes --static-loss-scale.')
parser.add_argument('--moe', action='store_true',
help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
help='number of experts in MoE')
parser.add_argument('--moe-top_k', type=int, default=2,
parser.add_argument('--moe-top-k', type=int, default=2,
help='top_k experts in hard gate of moe')
args = parser.parse_args()
args.tied = not args.not_tied
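The rename from `--moe-top_k` to `--moe-top-k` is a surface-level fix: argparse derives the attribute name by stripping the leading `--` and converting internal `-` characters to `_`, so both spellings are read back as `args.moe_top_k`, and the `moe_top_k=args.moe_top_k` call site below stays unchanged. A small check (written for illustration, not part of the diff):

# Illustration only: argparse maps both flag spellings to args.moe_top_k.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe', action='store_true',
                    help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64)
parser.add_argument('--moe-top-k', type=int, default=2)

args = parser.parse_args(['--moe', '--moe-num-expert', '64', '--moe-top-k', '4'])
print(args.moe, args.moe_num_expert, args.moe_top_k)  # True 64 4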
@@ -285,7 +287,7 @@ else:
ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
same_length=args.same_length, attn_type=args.attn_type,
clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
model.apply(weights_init)
model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
......
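Taken together, these changes make MoE opt-in: without `--moe` every decoder layer builds the original dense `PositionwiseFF`; with it, each layer uses the MoE feed-forward with `moe_num_expert` experts and a hard top-`moe_top_k` gate. A quick way to exercise both paths is a small construction test like the sketch below; the import path, the tiny hyper-parameters, and the `model(data, target)` call returning `[loss] + mems` are assumptions based on the surrounding Transformer-XL code, not part of the diff, and the MoE branch additionally needs FastMoE with a CUDA device.

# Sketch only (not part of the diff): build and run one forward pass per path.
import torch
from mem_transformer import MemTransformerLM  # assumed local module of this example

device = torch.device('cuda')  # FastMoE's fused kernels expect a GPU
for use_moe in (False, True):
    model = MemTransformerLM(
        n_token=1000, n_layer=2, n_head=2, d_model=64, d_head=32, d_inner=128,
        dropout=0.1, dropatt=0.0, tgt_len=16, ext_len=0, mem_len=16,
        moe=use_moe, moe_num_expert=4, moe_top_k=2).to(device)
    data = torch.randint(0, 1000, (16, 4), device=device)    # (tgt_len, batch)
    target = torch.randint(0, 1000, (16, 4), device=device)
    loss = model(data, target)[0]     # forward returns [loss] + new memory segments
    print('moe' if use_moe else 'dense', loss.float().mean().item())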