Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
9e67148c
Commit
9e67148c
authored
Feb 23, 2021
by
Rick Ho
Browse files
optional layer-norm in transformer mlp module
parent
b97483a4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
3 deletions
+8
-3
fmoe/transformer.py
fmoe/transformer.py
+8
-3
No files found.
fmoe/transformer.py
View file @
9e67148c
...
@@ -47,6 +47,7 @@ class FMoETransformerMLP(FMoE):
...
@@ -47,6 +47,7 @@ class FMoETransformerMLP(FMoE):
activation
=
torch
.
nn
.
functional
.
gelu
,
activation
=
torch
.
nn
.
functional
.
gelu
,
gate
=
NaiveGate
,
gate
=
NaiveGate
,
top_k
=
2
,
top_k
=
2
,
do_lnorm
=
False
,
pre_lnorm
=
False
,
pre_lnorm
=
False
,
expert_dp_comm
=
'none'
expert_dp_comm
=
'none'
):
):
...
@@ -55,7 +56,11 @@ class FMoETransformerMLP(FMoE):
...
@@ -55,7 +56,11 @@ class FMoETransformerMLP(FMoE):
self
.
experts
=
_Expert
(
num_expert
,
d_model
,
d_hidden
,
activation
,
self
.
experts
=
_Expert
(
num_expert
,
d_model
,
d_hidden
,
activation
,
rank
=
self
.
mp_rank
)
rank
=
self
.
mp_rank
)
self
.
pre_lnorm
=
pre_lnorm
self
.
pre_lnorm
=
pre_lnorm
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
if
do_lnorm
:
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
pre_lnorm
=
pre_lnorm
else
:
self
.
pre_lnorm
=
None
self
.
mark_parallel_comm
(
expert_dp_comm
)
self
.
mark_parallel_comm
(
expert_dp_comm
)
def
forward
(
self
,
inp
:
torch
.
Tensor
):
def
forward
(
self
,
inp
:
torch
.
Tensor
):
...
@@ -65,9 +70,9 @@ class FMoETransformerMLP(FMoE):
...
@@ -65,9 +70,9 @@ class FMoETransformerMLP(FMoE):
'''
'''
original_shape
=
inp
.
shape
original_shape
=
inp
.
shape
inp
=
inp
.
reshape
(
-
1
,
self
.
d_model
)
inp
=
inp
.
reshape
(
-
1
,
self
.
d_model
)
if
self
.
pre_lnorm
:
if
self
.
pre_lnorm
is
not
None
and
self
.
pre_lnorm
:
inp
=
self
.
layer_norm
(
inp
)
inp
=
self
.
layer_norm
(
inp
)
output
=
super
().
forward
(
inp
)
+
inp
output
=
super
().
forward
(
inp
)
+
inp
if
not
self
.
pre_lnorm
:
if
self
.
pre_lnorm
is
not
None
and
not
self
.
pre_lnorm
:
output
=
self
.
layer_norm
(
output
)
output
=
self
.
layer_norm
(
output
)
return
output
.
reshape
(
original_shape
)
return
output
.
reshape
(
original_shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment