OpenDAS / FastMoE · Commit 27c89b5a
Authored Feb 22, 2021 by Rick Ho
Parent: 87dad9d5

customized dp-comm and hidden-hidden-size
Showing 3 changed files with 13 additions and 9 deletions:

  fmoe/layers.py       +2  -5
  fmoe/megatron.py     +8  -2
  fmoe/transformer.py  +3  -2
fmoe/layers.py

@@ -157,17 +157,14 @@ class FMoE(nn.Module):
             base_idx += batch_size
         return torch.cat(outputs, dim=0)
 
-    def mark_parallel_comm(self):
+    def mark_parallel_comm(self, expert_dp_comm='none'):
         r'''
         Automatically mark the data parallel comms of the parameters within the
         module. This can be typically called at the end of the __init__ function
         in child classes.
         '''
         if self.experts is not None:
-            if self.world_size > self.mp_size:
-                comm = 'none'
-            else:
-                comm = 'dp'
+            comm = expert_dp_comm
             if isinstance(self.experts, list):
                 for e in self.experts:
                     mark_module_parallel_comm(e, comm)
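The effect of this change is that the comm tag, which used to be derived from world_size versus mp_size, is now supplied by the caller and forwarded unchanged to mark_module_parallel_comm. That helper is not part of this diff; as a rough sketch of what such a tagging step amounts to (the dp_comm attribute name is an assumption for illustration, not something the hunk shows):

import torch.nn as nn


def mark_module_parallel_comm(module: nn.Module, comm: str) -> None:
    # Sketch only: label every parameter with the name of the communication
    # group that should own its gradient synchronization, e.g. 'dp' for the
    # data-parallel all-reduce or 'none' to keep it out of that all-reduce.
    for p in module.parameters():
        setattr(p, 'dp_comm', comm)  # attribute name assumed for illustration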
fmoe/megatron.py

@@ -24,7 +24,7 @@ class MegatronMLP(FMoETransformerMLP):
         else:
             world_size = args.world_size
         super().__init__(args.num_experts,
-                d_model=args.hidden_size, d_hidden=args.hidden_size * 4,
+                d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
                 world_size=world_size, mp_group=group)
         self.bias = torch.nn.parameter.Parameter(
             torch.zeros(args.hidden_size, dtype=torch.float32)
@@ -34,7 +34,8 @@ class MegatronMLP(FMoETransformerMLP):
         return super().forward(inp), self.bias
 
 
-def fmoefy(model, num_experts=None, distributed_experts=True):
+def fmoefy(model, num_experts=None, distributed_experts=True,
+        hidden_hidden_size=None):
     r'''
     Replace MLP layers in a transformer-based model in Megatron by MoE.
     * `model` should be a standard Megatron model that has
@@ -57,6 +58,11 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
         'num_experts' in args
     ), 'num_experts should be specified in arguments or fmoefy function'
 
+    if hidden_hidden_size is not None:
+        args.hidden_hidden_size = hidden_hidden_size
+    elif not hasattr(args, 'hidden_hidden_size'):
+        args.hidden_hidden_size = args.hidden_size * 4
+
     # Set distributed_experts to None to use default setting in args
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
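A hedged usage sketch of the extended fmoefy signature inside a Megatron-LM training script: only the new hidden_hidden_size keyword and its fallback to args.hidden_size * 4 are taken from the diff; the concrete sizes below and the surrounding setup are illustrative, and `model` stands for an already-built Megatron model.

# Sketch of a Megatron-LM setup step; `model` is the Megatron model that has
# already been constructed, as assumed by fmoefy's docstring.
from fmoe.megatron import fmoefy

# Before this commit the expert MLP width was fixed to hidden_size * 4.
# Passing hidden_hidden_size makes it explicit, e.g. to shrink each expert
# while raising num_experts; omitting it falls back to hidden_size * 4.
model = fmoefy(model, num_experts=8, hidden_hidden_size=1024)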
fmoe/transformer.py

@@ -47,7 +47,8 @@ class FMoETransformerMLP(FMoE):
             activation=torch.nn.functional.gelu,
             gate=NaiveGate,
             top_k=2,
-            pre_lnorm=False
+            pre_lnorm=False,
+            expert_dp_comm='none'
             ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
@@ -55,7 +56,7 @@ class FMoETransformerMLP(FMoE):
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
         self.layer_norm = nn.LayerNorm(d_model)
-        self.mark_parallel_comm()
+        self.mark_parallel_comm(expert_dp_comm)
 
     def forward(self, inp: torch.Tensor):
         r'''
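For reference, a minimal construction sketch using the new keyword. The constructor defaults are not visible in this hunk, so the num_expert, d_model, and d_hidden values are illustrative; expert_dp_comm='dp' asks mark_parallel_comm to include the expert weights in the data-parallel gradient all-reduce instead of the default 'none'.

import torch
from fmoe.transformer import FMoETransformerMLP

# Illustrative sizes; only the keyword names appear in the diff above, and the
# forward pass assumes a CUDA build of fastmoe.
mlp = FMoETransformerMLP(num_expert=4, d_model=1024, d_hidden=4096,
                         expert_dp_comm='dp').cuda()

x = torch.randn(8, 1024, device='cuda')
y = mlp(x)  # routed through the top-k experts; output keeps the input shape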