OpenDAS / FastMoE / Commits

Commit ae10e942
Authored Oct 12, 2021 by Rick Ho
Parent 07a5d8ac

    update calculation for megatron hhs

Showing 2 changed files with 13 additions and 14 deletions:

    fmoe/megatron/layers.py   +5 -12
    fmoe/megatron/patch.py    +8 -2
fmoe/megatron/layers.py

@@ -74,11 +74,7 @@ class MegatronMLP(FMoETransformerMLP):
     communication group `group` to replace the original MLP layer in Megatron.
     """
 
-    def __init__(self, args, layer_idx):
-        assert (
-            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
-            == 0
-        ), "Batch size x sequence length should be multiple of mp size"
+    def __init__(self, args, layer_idx, gate=None):
         if not args.distributed_experts:
             world_size = 1
             moe_group = None
@@ -87,7 +83,6 @@ class MegatronMLP(FMoETransformerMLP):
             from megatron.mpu import get_data_parallel_group
             moe_group = get_data_parallel_group()
-        gate = None
         if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
             gate = NaiveGate
@@ -100,7 +95,7 @@ class MegatronMLP(FMoETransformerMLP):
         elif args.balance_strategy == "switch":
             from fmoe.gates import SwitchGate
             gate = SwitchGate
-        else:
+        elif gate is None:
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)
 
         super().__init__(
@@ -152,6 +147,7 @@ def fmoefy(
     distributed_experts=True,
     hidden_hidden_size=None,
     top_k=None,
+    gate=None,
 ):
     r"""
     Replace MLP layers in a transformer-based model in Megatron by MoE.
@@ -186,13 +182,10 @@ def fmoefy(
     elif not hasattr(args, "top_k"):
         args.top_k = 2
 
-    if hidden_hidden_size is not None:
-        args.hidden_hidden_size = hidden_hidden_size
-    elif not hasattr(args, "hidden_hidden_size"):
-        args.hidden_hidden_size = args.hidden_size * 4 // args.tensor_model_parallel_size
+    args.hidden_hidden_size = hidden_hidden_size
 
     for idx, l in enumerate(model.language_model.transformer.layers):
-        l.mlp = MegatronMLP(args, idx)
+        l.mlp = MegatronMLP(args, idx, gate=gate)
 
     # initialize gate hook
     num_layers = len(model.language_model.transformer.layers)
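Taken together, the layers.py changes thread an optional gate class from fmoefy down to MegatronMLP instead of hard-coding gate = None, and the final branch now asserts only when no gate was supplied. Below is a minimal usage sketch; it is not part of this commit, and the wrapper function, expert count, and the choice of GShardGate are illustrative assumptions. Note that, as shown in the hunks above, the balance_strategy branches still assign the built-in gates, so a user-supplied gate only takes effect when none of those branches match.

from fmoe.gates import GShardGate
from fmoe.megatron.layers import fmoefy

def build_moe_model(model):
    # Hypothetical helper: `model` is assumed to be a Megatron transformer
    # model whose MLP layers fmoefy replaces; the numbers are illustrative.
    return fmoefy(
        model,
        num_experts=8,            # experts per MoE layer (assumed)
        hidden_hidden_size=1024,  # per-expert FFN width (cf. patch.py below)
        top_k=2,
        gate=GShardGate,          # forwarded to MegatronMLP(args, idx, gate=gate)
    )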
fmoe/megatron/patch.py

@@ -46,17 +46,23 @@ def patch_forward_step(forward_step_func):
     return forward_step_with_balance_loss
 
 
-def patch_model_provider(model_provider):
+def patch_model_provider(model_provider, gate=None):
     from megatron import get_args
 
     def fmoefied_model_provider():
         from .layers import fmoefy
         args = get_args()
+        hhs = args.hidden_size * 4
+        assert hhs % args.top_k == 0
+        hhs = hhs // args.top_k
+        assert hhs % args.tensor_model_parallel_size == 0
+        hhs = hhs // args.tensor_model_parallel_size
         return fmoefy(
             model_provider(),
             num_experts=args.num_experts,
-            hidden_hidden_size=4 * args.hidden_size // args.top_k,
+            hidden_hidden_size=hhs,
             top_k=args.top_k,
+            gate=gate,
         )
 
     return fmoefied_model_provider
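The "hhs" in the commit message is the hidden_hidden_size passed to fmoefy. The old formula, 4 * args.hidden_size // args.top_k, ignored tensor model parallelism; the new code divides the 4x FFN width first by top_k and then by tensor_model_parallel_size, asserting divisibility at each step rather than silently truncating. A small worked example with assumed values (hidden_size=1024, top_k=2, tensor_model_parallel_size=2; these numbers are illustrative, not from the commit):

# Mirror of the hhs calculation above, with assumed example values.
hidden_size = 1024
top_k = 2
tensor_model_parallel_size = 2

hhs = hidden_size * 4                        # 4096: 4x FFN width
assert hhs % top_k == 0
hhs = hhs // top_k                           # 2048: divided by top_k
assert hhs % tensor_model_parallel_size == 0
hhs = hhs // tensor_model_parallel_size      # 1024: divided by the tensor-parallel size,
                                             # the factor missing from the old formula
print(hhs)  # 1024, vs. 2048 from the old 4 * hidden_size // top_k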