Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
57bdfe88
Commit
57bdfe88
authored
Oct 26, 2021
by
Rick Ho
Browse files
fix megatron adapter for swipe
parent
8a56481b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
22 additions
and
6 deletions
+22
-6
fmoe/gates/__init__.py
fmoe/gates/__init__.py
+2
-0
fmoe/gates/base_gate.py
fmoe/gates/base_gate.py
+4
-0
fmoe/megatron/balance.py
fmoe/megatron/balance.py
+6
-3
fmoe/megatron/layers.py
fmoe/megatron/layers.py
+3
-0
fmoe/megatron/patch.py
fmoe/megatron/patch.py
+7
-3
No files found.
fmoe/gates/__init__.py
View file @
57bdfe88
...
@@ -7,3 +7,5 @@ from .noisy_gate import NoisyGate
...
@@ -7,3 +7,5 @@ from .noisy_gate import NoisyGate
from
.gshard_gate
import
GShardGate
from
.gshard_gate
import
GShardGate
from
.switch_gate
import
SwitchGate
from
.switch_gate
import
SwitchGate
from
.swipe_gate
import
SwipeGate
fmoe/gates/base_gate.py
View file @
57bdfe88
...
@@ -23,3 +23,7 @@ class BaseGate(nn.Module):
...
@@ -23,3 +23,7 @@ class BaseGate(nn.Module):
if
clear
:
if
clear
:
self
.
loss
=
None
self
.
loss
=
None
return
loss
return
loss
@
property
def
has_loss
(
self
):
return
self
.
loss
is
not
None
fmoe/megatron/balance.py
View file @
57bdfe88
...
@@ -51,9 +51,12 @@ def add_balance_log(model, writer, iteration):
...
@@ -51,9 +51,12 @@ def add_balance_log(model, writer, iteration):
while
hasattr
(
model
,
'module'
):
while
hasattr
(
model
,
'module'
):
model
=
model
.
module
model
=
model
.
module
balance_dict_tensor
=
torch
.
vstack
(
losses
=
[
l
.
mlp
.
gate
.
get_loss
(
clear
=
True
)
[
l
.
mlp
.
gate
.
get_loss
(
clear
=
True
)
for
l
in
model
.
language_model
.
transformer
.
layers
]
for
l
in
model
.
language_model
.
transformer
.
layers
).
detach
()
if
l
.
mlp
.
gate
.
has_loss
]
if
len
(
losses
)
==
0
:
return
balance_dict_tensor
=
torch
.
vstack
(
losses
).
detach
()
world_group
=
get_torch_default_comm
()
world_group
=
get_torch_default_comm
()
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
world_group
)
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
world_group
)
torch
.
distributed
.
all_reduce
(
balance_dict_tensor
,
group
=
world_group
)
torch
.
distributed
.
all_reduce
(
balance_dict_tensor
,
group
=
world_group
)
...
...
fmoe/megatron/layers.py
View file @
57bdfe88
...
@@ -95,6 +95,9 @@ class MegatronMLP(FMoETransformerMLP):
...
@@ -95,6 +95,9 @@ class MegatronMLP(FMoETransformerMLP):
elif
args
.
balance_strategy
==
"switch"
:
elif
args
.
balance_strategy
==
"switch"
:
from
fmoe.gates
import
SwitchGate
from
fmoe.gates
import
SwitchGate
gate
=
SwitchGate
gate
=
SwitchGate
elif
args
.
balance_strategy
==
"swipe"
:
from
fmoe.gates
import
SwipeGate
gate
=
SwipeGate
elif
gate
is
None
:
elif
gate
is
None
:
assert
False
,
"Undefined balance strategy {}"
%
(
args
.
balance_strategy
)
assert
False
,
"Undefined balance strategy {}"
%
(
args
.
balance_strategy
)
...
...
fmoe/megatron/patch.py
View file @
57bdfe88
...
@@ -20,15 +20,19 @@ def patch_forward_step(forward_step_func):
...
@@ -20,15 +20,19 @@ def patch_forward_step(forward_step_func):
args
=
get_args
()
args
=
get_args
()
output
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
output
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
if
not
is_pipeline_last_stage
()
or
not
args
.
balance_strategy
or
args
.
balance_strategy
==
'naive'
:
if
not
is_pipeline_last_stage
()
or
not
args
.
balance_strategy
:
return
output
return
output
loss_name
=
args
.
balance_strategy
+
"_loss"
while
hasattr
(
model
,
'module'
):
while
hasattr
(
model
,
'module'
):
model
=
model
.
module
model
=
model
.
module
loss_list
=
[
l
.
mlp
.
gate
.
get_loss
(
clear
=
False
).
view
(
1
)
loss_list
=
[
l
.
mlp
.
gate
.
get_loss
(
clear
=
False
).
view
(
1
)
for
l
in
model
.
language_model
.
transformer
.
layers
]
for
l
in
model
.
language_model
.
transformer
.
layers
if
l
.
mlp
.
gate
.
has_loss
]
if
len
(
loss_list
)
==
0
:
return
output
loss_name
=
args
.
balance_strategy
+
"_loss"
(
loss
,
state_dict
),
bal_loss
=
(
(
loss
,
state_dict
),
bal_loss
=
(
output
,
output
,
torch
.
cat
(
loss_list
).
mean
()
*
args
.
balance_loss_weight
torch
.
cat
(
loss_list
).
mean
()
*
args
.
balance_loss_weight
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment