OpenDAS / FastMoE
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "214520c66aa0697205667c21aff30452c88893ef"
Commit 0a942e3f, authored Feb 25, 2021 by Jiezhong Qiu
allow user to indicate whether to use MoE
parent bdb64914
Showing 3 changed files, with 36 additions and 19 deletions:

  examples/transformer-xl/mem_transformer.py                  +31 -17
  examples/transformer-xl/scripts/run_enwik8_base_gshard.sh    +1  -0
  examples/transformer-xl/train.py                             +4  -2
examples/transformer-xl/mem_transformer.py

```diff
@@ -143,7 +143,7 @@ class MultiHeadAttn(nn.Module):
 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 moe_num_expert=64, moe_top_k=2):
+                 moe=False, moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()
 
         self.n_head = n_head
@@ -395,10 +395,14 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'),
-            moe_num_expert=kwargs.get('moe_num_expert'),
-            moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'),
+                moe_num_expert=kwargs.get('moe_num_expert'),
+                moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -415,10 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'),
-            moe_num_expert=kwargs.get('moe_num_expert'),
-            moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'),
+                moe_num_expert=kwargs.get('moe_num_expert'),
+                moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -436,10 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'),
-            moe_num_expert=kwargs.get('moe_num_expert'),
-            moe_top_k=kwargs.get('moe_top_k'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'),
+                moe_num_expert=kwargs.get('moe_num_expert'),
+                moe_top_k=kwargs.get('moe_top_k'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
@@ -521,7 +535,7 @@ class MemTransformerLM(nn.Module):
                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
+                 sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
 
         self.n_token = n_token
@@ -553,7 +567,7 @@ class MemTransformerLM(nn.Module):
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type == 1: # learnable embeddings
             for i in range(n_layer):
@@ -562,7 +576,7 @@ class MemTransformerLM(nn.Module):
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]: # absolute embeddings
             for i in range(n_layer):
@@ -570,7 +584,7 @@ class MemTransformerLM(nn.Module):
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
 
         self.sample_softmax = sample_softmax
```
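The same pattern is applied in all three decoder-layer classes above: the position-wise FFN is now selected by the new moe keyword instead of always being the MoE variant. The sketch below is not code from the repository; PositionwiseFF and CustomizedMoEPositionwiseFF are placeholder stand-ins for the classes in mem_transformer.py. It isolates that selection logic and shows why the test reads kwargs.get('moe') is False: a missing key returns None, which is not the object False, so only an explicit moe=False keeps the dense feed-forward.

```python
# Minimal sketch of the selection added in this commit -- not code from the
# repository. PositionwiseFF and CustomizedMoEPositionwiseFF are placeholders
# standing in for the classes defined in mem_transformer.py.

class PositionwiseFF:                       # dense feed-forward stand-in
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        self.kind = 'dense'

class CustomizedMoEPositionwiseFF:          # MoE feed-forward stand-in
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
                 moe_num_expert=64, moe_top_k=2):
        self.kind = 'moe'

def build_pos_ff(d_model, d_inner, dropout, **kwargs):
    # Same test as the diff. dict.get returns None for a missing key, and
    # None is not the object False, so only an explicit moe=False takes the
    # dense branch; anything else falls through to the MoE feed-forward.
    if kwargs.get('moe') is False:
        return PositionwiseFF(d_model, d_inner, dropout,
                              pre_lnorm=kwargs.get('pre_lnorm'))
    return CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                       pre_lnorm=kwargs.get('pre_lnorm'),
                                       moe_num_expert=kwargs.get('moe_num_expert'),
                                       moe_top_k=kwargs.get('moe_top_k'))

print(build_pos_ff(512, 2048, 0.1, moe=False).kind)                     # dense
print(build_pos_ff(512, 2048, 0.1, moe=True,
                   moe_num_expert=64, moe_top_k=2).kind)                # moe
```

In the model itself, MemTransformerLM forwards moe, moe_num_expert, and moe_top_k into each layer's **kwargs, so the command-line flag decides this branch for every layer.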
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh

```diff
@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
         --batch_size 22 \
         --multi_gpu \
         --gpu0_bsz 4 \
+        --moe --moe-num-expert 64 --moe-top-k 2 \
         ${@:2}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'
```
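The flag string added here matches the new train.py options one-to-one, and the trailing backslash keeps it on the same python invocation, with any extra user arguments (${@:2}) appended after it. Below is an illustrative launcher sketch of that invocation in Python; it is an assumption for illustration only, and the script's other flags (dataset path, model size, and so on) lie outside this hunk and are omitted.

```python
# Illustrative launcher sketch (not part of the commit): it assembles the
# flags visible in this hunk the way the shell script does. The script's
# remaining flags are outside the hunk and omitted, so the command below is
# not a complete training invocation.
import sys

moe_flags = ['--moe', '--moe-num-expert', '64', '--moe-top-k', '2']

cmd = [sys.executable, 'train.py',
       '--batch_size', '22',
       '--multi_gpu',
       '--gpu0_bsz', '4',
       *moe_flags,
       *sys.argv[1:]]          # plays the role of the script's "${@:2}"

print(' '.join(cmd))
# e.g. subprocess.run(cmd, check=True) once the omitted flags are added back
```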
examples/transformer-xl/train.py

```diff
@@ -141,9 +141,11 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                          ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
 parser.add_argument('--moe-num-expert', type=int, default=64,
                     help='number of experts in MoE')
-parser.add_argument('--moe-top_k', type=int, default=2,
+parser.add_argument('--moe-top-k', type=int, default=2,
                     help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied
@@ -285,7 +287,7 @@ else:
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
         clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
-        moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
+        moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
 args.n_all_param = sum([p.nelement() for p in model.parameters()])
```
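One detail worth noting in the first hunk: argparse derives attribute names from long option strings by turning dashes into underscores, so the old --moe-top_k and the new --moe-top-k both parse into args.moe_top_k; the rename only makes the flag's spelling consistent with --moe-num-expert. The short check below is not from the repository; it just shows what the three options produce and what therefore reaches MemTransformerLM.

```python
# Quick check (not from the repository) of how the new flags parse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe', action='store_true',
                    help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

# The flags the updated run_enwik8_base_gshard.sh passes:
args = parser.parse_args(['--moe', '--moe-num-expert', '64', '--moe-top-k', '2'])
print(args.moe, args.moe_num_expert, args.moe_top_k)   # True 64 2

# Without --moe, the store_true default applies, so MemTransformerLM receives
# moe=False and every decoder layer keeps the dense PositionwiseFF.
args = parser.parse_args([])
print(args.moe)                                        # False
```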