OpenDAS / FastMoE

Commit 0a942e3f, authored Feb 25, 2021 by Jiezhong Qiu
Parent: bdb64914

allow user to indicate whether to use MoE

Showing 3 changed files with 36 additions and 19 deletions:

examples/transformer-xl/mem_transformer.py (+31, -17)
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh (+1, -0)
examples/transformer-xl/train.py (+4, -2)
examples/transformer-xl/mem_transformer.py (+31, -17)

```diff
@@ -143,7 +143,7 @@ class MultiHeadAttn(nn.Module):
 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 moe_num_expert=64, moe_top_k=2):
+                 moe=False, moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()
 
         self.n_head = n_head
@@ -395,10 +395,14 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'),
-                                                  moe_num_expert=kwargs.get('moe_num_expert'),
-                                                  moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                         pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                                                      pre_lnorm=kwargs.get('pre_lnorm'),
+                                                      moe_num_expert=kwargs.get('moe_num_expert'),
+                                                      moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -415,10 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'),
-                                                  moe_num_expert=kwargs.get('moe_num_expert'),
-                                                  moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                         pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                                                      pre_lnorm=kwargs.get('pre_lnorm'),
+                                                      moe_num_expert=kwargs.get('moe_num_expert'),
+                                                      moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -436,10 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'),
-                                                  moe_num_expert=kwargs.get('moe_num_expert'),
-                                                  moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                         pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                                                      pre_lnorm=kwargs.get('pre_lnorm'),
+                                                      moe_num_expert=kwargs.get('moe_num_expert'),
+                                                      moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
@@ -521,7 +535,7 @@ class MemTransformerLM(nn.Module):
                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
+                 sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
         self.n_token = n_token
@@ -553,7 +567,7 @@ class MemTransformerLM(nn.Module):
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type == 1: # learnable embeddings
             for i in range(n_layer):
@@ -562,7 +576,7 @@ class MemTransformerLM(nn.Module):
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]: # absolute embeddings
             for i in range(n_layer):
@@ -570,7 +584,7 @@ class MemTransformerLM(nn.Module):
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
 
         self.sample_softmax = sample_softmax
```
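The pattern repeated in each decoder layer above has one subtlety: `kwargs.get('moe')` returns `None` when the keyword is missing, and `None is False` evaluates to false, so only an explicit `moe=False` selects the dense feed-forward path; omitting the keyword keeps the MoE block. The sketch below is not part of the commit and only isolates that selection logic; `build_pos_ff` and the two stub class bodies are illustrative stand-ins for the real `PositionwiseFF` and `CustomizedMoEPositionwiseFF` defined in mem_transformer.py.

```python
import torch.nn as nn

# Illustrative stubs; the real classes live in mem_transformer.py.
class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super().__init__()
        self.core = nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(),
                                  nn.Dropout(dropout), nn.Linear(d_inner, d_model))

class CustomizedMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
                 moe_num_expert=64, moe_top_k=2):
        super().__init__()
        self.moe_num_expert, self.moe_top_k = moe_num_expert, moe_top_k

def build_pos_ff(d_model, d_inner, dropout, **kwargs):
    # Mirrors the per-layer selection added by this commit: only an explicit
    # moe=False falls back to the dense FFN; a missing key yields None, which
    # fails the `is False` test, so the MoE branch is taken.
    if kwargs.get('moe') is False:
        return PositionwiseFF(d_model, d_inner, dropout,
                              pre_lnorm=kwargs.get('pre_lnorm'))
    return CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                       pre_lnorm=kwargs.get('pre_lnorm'),
                                       moe_num_expert=kwargs.get('moe_num_expert'),
                                       moe_top_k=kwargs.get('moe_top_k'))

dense = build_pos_ff(512, 2048, 0.1, moe=False, pre_lnorm=False)
sparse = build_pos_ff(512, 2048, 0.1, moe=True, pre_lnorm=False,
                      moe_num_expert=64, moe_top_k=2)
print(type(dense).__name__, type(sparse).__name__)
# PositionwiseFF CustomizedMoEPositionwiseFF
```

Within this commit the distinction rarely matters, since MemTransformerLM always forwards an explicit boolean (`moe=moe`, default False) to its layers; the `None` case only arises for code that constructs the decoder layers directly without the `moe` keyword.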
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh (+1, -0)

```diff
@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
         --batch_size 22 \
         --multi_gpu \
         --gpu0_bsz 4 \
+        --moe --moe-num-expert 64 --moe-top-k 2 \
         ${@:2}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'
```
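The expert count and top-k passed here (64 and 2) are the same values the new argparse options in train.py use as defaults, so the functional change in this line is the `--moe` switch itself: without it, `args.moe` stays False and, after this commit, the model falls back to the dense position-wise FFN.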
examples/transformer-xl/train.py (+4, -2)

```diff
@@ -141,9 +141,11 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
 parser.add_argument('--moe-num-expert', type=int, default=64,
                     help='number of experts in MoE')
-parser.add_argument('--moe-top_k', type=int, default=2,
+parser.add_argument('--moe-top-k', type=int, default=2,
                     help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied
@@ -285,7 +287,7 @@ else:
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
         clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
-        moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
+        moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
 args.n_all_param = sum([p.nelement() for p in model.parameters()])
```
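As a quick sanity check of the new options, here is a minimal sketch that is not part of the commit and covers only the three MoE arguments (the example argv is illustrative). It also shows why the rename from `--moe-top_k` to `--moe-top-k` is safe: argparse converts dashes in option names to underscores when building the Namespace, so the attribute consumed by the MemTransformerLM call is still `args.moe_top_k`.

```python
import argparse

# Only the MoE options touched by this commit; the real train.py defines
# many more arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--moe', action='store_true',
                    help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

# Flags as passed by the updated run_enwik8_base_gshard.sh (illustrative argv).
args = parser.parse_args(['--moe', '--moe-num-expert', '64', '--moe-top-k', '2'])
print(args.moe, args.moe_num_expert, args.moe_top_k)  # True 64 2

# With --moe omitted, args.moe is False, so MemTransformerLM receives
# moe=False and every layer builds the dense PositionwiseFF instead.
print(parser.parse_args([]).moe)  # False
```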