OpenDAS / FastMoE · Commits

Commit 0a942e3f
authored Feb 25, 2021 by Jiezhong Qiu

allow user to indicate whether to use MoE

parent bdb64914
Showing 3 changed files with 36 additions and 19 deletions:

  examples/transformer-xl/mem_transformer.py                  +31  -17
  examples/transformer-xl/scripts/run_enwik8_base_gshard.sh    +1   -0
  examples/transformer-xl/train.py                             +4   -2

examples/transformer-xl/mem_transformer.py

@@ -143,7 +143,7 @@ class MultiHeadAttn(nn.Module):
 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 moe_num_expert=64, moe_top_k=2):
+                 moe=False, moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()
 
         self.n_head = n_head
@@ -395,10 +395,14 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'),
-                                                  moe_num_expert=kwargs.get('moe_num_expert'),
-                                                  moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                         pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                                                      pre_lnorm=kwargs.get('pre_lnorm'),
+                                                      moe_num_expert=kwargs.get('moe_num_expert'),
+                                                      moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -415,10 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'),
-                                                  moe_num_expert=kwargs.get('moe_num_expert'),
-                                                  moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                         pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                                                      pre_lnorm=kwargs.get('pre_lnorm'),
+                                                      moe_num_expert=kwargs.get('moe_num_expert'),
+                                                      moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -436,10 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                             d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                                                  pre_lnorm=kwargs.get('pre_lnorm'),
-                                                  moe_num_expert=kwargs.get('moe_num_expert'),
-                                                  moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                         pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                                                      pre_lnorm=kwargs.get('pre_lnorm'),
+                                                      moe_num_expert=kwargs.get('moe_num_expert'),
+                                                      moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
@@ -521,7 +535,7 @@ class MemTransformerLM(nn.Module):
                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
+                 sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
         self.n_token = n_token
@@ -553,7 +567,7 @@ class MemTransformerLM(nn.Module):
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type == 1: # learnable embeddings
             for i in range(n_layer):
@@ -562,7 +576,7 @@ class MemTransformerLM(nn.Module):
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]: # absolute embeddings
             for i in range(n_layer):
@@ -570,7 +584,7 @@ class MemTransformerLM(nn.Module):
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
 
         self.sample_softmax = sample_softmax

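Taken together, the mem_transformer.py hunks thread a single boolean `moe` from the model constructor down to every decoder layer via `**kwargs`, and each layer now builds a dense `PositionwiseFF` when `kwargs.get('moe') is False` and the MoE `CustomizedMoEPositionwiseFF` otherwise. The sketch below is a minimal, self-contained illustration of that selection pattern; the `Toy*` classes are hypothetical stand-ins (a plain two-layer FFN and a naive top-k-gated MoE FFN), not FastMoE's implementation.

```python
# Illustrative sketch only: ToyPositionwiseFF / ToyMoEPositionwiseFF are
# hypothetical stand-ins for PositionwiseFF / CustomizedMoEPositionwiseFF,
# not FastMoE's implementation.
import torch
import torch.nn as nn


class ToyPositionwiseFF(nn.Module):
    """Dense two-layer feed-forward block."""
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_inner, d_model), nn.Dropout(dropout))
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, x):
        if self.pre_lnorm:
            return x + self.net(self.layer_norm(x))
        return self.layer_norm(x + self.net(x))


class ToyMoEPositionwiseFF(nn.Module):
    """Naive top-k gated mixture-of-experts feed-forward block."""
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
                 moe_num_expert=64, moe_top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(),
                          nn.Linear(d_inner, d_model))
            for _ in range(moe_num_expert)])
        self.gate = nn.Linear(d_model, moe_num_expert)
        self.top_k = moe_top_k
        self.drop = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def _moe(self, x):
        gate_prob = torch.softmax(self.gate(x), dim=-1)    # [..., num_expert]
        top_w, top_i = gate_prob.topk(self.top_k, dim=-1)  # route each token to k experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = top_i[..., k] == e                  # tokens whose k-th choice is expert e
                if mask.any():
                    out[mask] += top_w[..., k][mask].unsqueeze(-1) * expert(x[mask])
        return self.drop(out)

    def forward(self, x):
        if self.pre_lnorm:
            return x + self._moe(self.layer_norm(x))
        return self.layer_norm(x + self._moe(x))


def build_pos_ff(d_model, d_inner, dropout, **kwargs):
    """Mirrors the selection logic added to the DecoderLayer classes above."""
    if kwargs.get('moe') is False:
        return ToyPositionwiseFF(d_model, d_inner, dropout,
                                 pre_lnorm=kwargs.get('pre_lnorm'))
    # A missing 'moe' key also lands here: dict.get returns None, and
    # `None is False` evaluates to False.
    return ToyMoEPositionwiseFF(d_model, d_inner, dropout,
                                pre_lnorm=kwargs.get('pre_lnorm'),
                                moe_num_expert=kwargs.get('moe_num_expert'),
                                moe_top_k=kwargs.get('moe_top_k'))


if __name__ == '__main__':
    x = torch.randn(4, 16, 32)                             # [batch, seq, d_model]
    dense = build_pos_ff(32, 64, 0.1, moe=False, pre_lnorm=False)
    sparse = build_pos_ff(32, 64, 0.1, moe=True, pre_lnorm=False,
                          moe_num_expert=4, moe_top_k=2)
    print(dense(x).shape, sparse(x).shape)                 # both torch.Size([4, 16, 32])
```

One consequence of the `is False` test, mirrored in `build_pos_ff` above, is that a caller that omits the `moe` keyword entirely falls into the MoE branch, since `dict.get` returns `None`. In this commit every call site passes `moe=` explicitly, so the distinction does not matter in practice.
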
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh

@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
         --batch_size 22 \
         --multi_gpu \
         --gpu0_bsz 4 \
+        --moe --moe-num-expert 64 --moe-top-k 2 \
         ${@:2}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'

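The training script now opts into MoE explicitly. Note the flag spelling `--moe-top-k`, which matches the rename in train.py below: with argparse, dashes in a long option name become underscores in the parsed namespace, so both the old `--moe-top_k` and the new `--moe-top-k` end up as `args.moe_top_k`. The quick check below (not part of the commit) illustrates the mapping for the three flags used here.

```python
# Quick check (not part of the commit): argparse converts dashes in long
# option names to underscores when building attribute names.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe', action='store_true')
parser.add_argument('--moe-num-expert', type=int, default=64)
parser.add_argument('--moe-top-k', type=int, default=2)

args = parser.parse_args(['--moe', '--moe-num-expert', '64', '--moe-top-k', '2'])
print(args.moe, args.moe_num_expert, args.moe_top_k)   # True 64 2
```
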
examples/transformer-xl/train.py

@@ -141,9 +141,11 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
 parser.add_argument('--moe-num-expert', type=int, default=64,
                     help='number of experts in MoE')
-parser.add_argument('--moe-top_k', type=int, default=2,
+parser.add_argument('--moe-top-k', type=int, default=2,
                     help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied
@@ -285,7 +287,7 @@ else:
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
         clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
-        moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
+        moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
 
 args.n_all_param = sum([p.nelement() for p in model.parameters()])

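End to end, `--moe` now flows from the command line into `MemTransformerLM(..., moe=args.moe, ...)` and from there into each decoder layer's keyword arguments. The sketch below traces that path with hypothetical `Toy*` classes (not the repo's modules), parsing an empty argument list to emulate the dense baseline.

```python
# Sketch with hypothetical Toy* classes (not the repo's modules): the CLI flag
# is parsed into args.moe and handed to the model, which forwards it to every
# layer through keyword arguments.
import argparse
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    def __init__(self, d_model, **kwargs):
        super().__init__()
        # Same test as the patched DecoderLayer classes in mem_transformer.py:
        self.uses_moe = kwargs.get('moe') is not False
        self.ff = nn.Linear(d_model, d_model)        # placeholder for pos_ff


class ToyMemTransformerLM(nn.Module):
    def __init__(self, n_layer, d_model, moe=False, moe_num_expert=64, moe_top_k=2):
        super().__init__()
        self.layers = nn.ModuleList([
            ToyDecoderLayer(d_model, moe=moe,
                            moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
            for _ in range(n_layer)])


parser = argparse.ArgumentParser()
parser.add_argument('--moe', action='store_true')
parser.add_argument('--moe-num-expert', type=int, default=64)
parser.add_argument('--moe-top-k', type=int, default=2)
args = parser.parse_args([])                          # dense baseline: --moe not given

model = ToyMemTransformerLM(n_layer=2, d_model=8, moe=args.moe,
                            moe_num_expert=args.moe_num_expert,
                            moe_top_k=args.moe_top_k)
print([layer.uses_moe for layer in model.layers])     # [False, False]
```

Running the same sketch with `parser.parse_args(['--moe'])` flips every layer to the MoE branch.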