OpenDAS / FastMoE · Commit fa023f32
Authored Feb 24, 2021 by Jiezhong Qiu
Parent commit: 3bdfae96

allow user to indicate number of experts and topk
Showing 2 changed files with 28 additions and 12 deletions:

  examples/transformer-xl/mem_transformer.py   +20 -10
  examples/transformer-xl/train.py              +8  -2
examples/transformer-xl/mem_transformer.py

@@ -583,7 +583,8 @@ class MultiHeadAttn(nn.Module):
 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
+                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
+                 moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()
         self.n_head = n_head

@@ -819,10 +820,10 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
 from fmoe import FMoETransformerMLP
 class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
         def activation(x):
             return self.dropout(F.relu(x))
-        super().__init__(num_expert=64, d_model=d_model, d_hidden=d_inner, topk=2,
+        super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
                 pre_lnorm=pre_lnorm, activation=activation)
         self.dropout = nn.Dropout(dropout)
         self.bias = nn.Parameter(

@@ -841,7 +842,9 @@ class DecoderLayer(nn.Module):
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'))
+            pre_lnorm=kwargs.get('pre_lnorm'),
+            moe_num_expert=kwargs.get('moe_num_expert'),
+            moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None):

@@ -861,7 +864,9 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'))
+            pre_lnorm=kwargs.get('pre_lnorm'),
+            moe_num_expert=kwargs.get('moe_num_expert'),
+            moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):

@@ -882,7 +887,9 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'))
+            pre_lnorm=kwargs.get('pre_lnorm'),
+            moe_num_expert=kwargs.get('moe_num_expert'),
+            moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):

@@ -967,7 +974,7 @@ class MemTransformerLM(nn.Module):
                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1):
+                 sample_softmax=-1, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
         self.n_token = n_token

@@ -998,7 +1005,8 @@ class MemTransformerLM(nn.Module):
                     RelPartialLearnableDecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type == 1: # learnable embeddings
             for i in range(n_layer):

@@ -1006,14 +1014,16 @@ class MemTransformerLM(nn.Module):
                     RelLearnableDecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]: # absolute embeddings
             for i in range(n_layer):
                 self.layers.append(
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )

         self.sample_softmax = sample_softmax
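With this change, CustomizedMoEPositionwiseFF no longer hard-codes num_expert=64 and top_k=2 when calling into FMoETransformerMLP: the values come from the new moe_num_expert and moe_top_k arguments, which MemTransformerLM and every decoder-layer constructor now forward. The defaults stay at 64 and 2, so callers that pass nothing keep the old behaviour. Below is a minimal sketch of exercising the layer directly with user-chosen values; it assumes fastmoe is installed (its expert kernels generally expect a CUDA build) and that it is run from examples/transformer-xl/ so mem_transformer is importable, and the tensor sizes are illustrative only.

import torch
from mem_transformer import CustomizedMoEPositionwiseFF  # examples/transformer-xl/mem_transformer.py

# Pick the expert count and routing width instead of the previously
# hard-coded num_expert=64, top_k=2.
pos_ff = CustomizedMoEPositionwiseFF(d_model=512, d_inner=2048, dropout=0.1,
                                     pre_lnorm=False,
                                     moe_num_expert=16, moe_top_k=2)

x = torch.randn(36, 4, 512)   # (seq_len, batch, d_model); sizes are illustrative
y = pos_ff(x)                 # MoE feed-forward block; output shape matches the input
print(y.shape)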
examples/transformer-xl/train.py

@@ -167,8 +167,13 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe-num-expert', type=int, default=64,
+                    help='number of experts in MoE')
+parser.add_argument('--moe-top_k', type=int, default=2,
+                    help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied
+assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
 if args.d_embed < 0:
     args.d_embed = args.d_model

@@ -305,7 +310,8 @@ else:
         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
-        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
+        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
+        moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing

 args.n_all_param = sum([p.nelement() for p in model.parameters()])

@@ -571,7 +577,7 @@ def train():
             # for i in range(len(avg_nnzs)):
             #     avg_nnzs[i].reset()
             #     act_hist[i] /= act_hist[i].sum()
-            #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
+            #     prob, index = torch.top_k(act_hist[i], min(1024, act_hist[i].size(-1)))
             #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
             #         i+1,
             #         prob[:64].sum().item(),
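On the training-script side, the new --moe-num-expert and --moe-top_k flags are forwarded to MemTransformerLM, and the added assertion enforces that the expert pool is at least as wide as the top-k routing, since each token is dispatched to moe_top_k distinct experts. The following is a self-contained sketch of the same flag handling and invariant; it mirrors the diff above rather than importing the example script, and the parsed values are illustrative.

import argparse

parser = argparse.ArgumentParser(description='MoE flags (sketch of the new CLI surface)')
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top_k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

# argparse turns '-' into '_' when deriving attribute names:
# --moe-num-expert -> args.moe_num_expert, --moe-top_k -> args.moe_top_k
args = parser.parse_args(['--moe-num-expert', '16', '--moe-top_k', '2'])

# Each token is routed to moe_top_k experts, so the pool must be at least that large.
assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
print(args.moe_num_expert, args.moe_top_k)  # -> 16 2

Against the real script, the equivalent invocation would be along the lines of python train.py --moe-num-expert 16 --moe-top_k 2, with the usual data and model arguments added; a value pair that violates the assertion aborts before model construction.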