OpenDAS / FastMoE · Commits

Commit 5f8ba136, authored Oct 10, 2021 by Rick Ho
Parent: 33fc3aca

    tidy up megatron compatible layer

Showing 6 changed files with 113 additions and 114 deletions (+113, -114)
fmoe/layers.py                +2   -2
fmoe/megatron/__init__.py     +3   -2
fmoe/megatron/balance.py      +0   -61
fmoe/megatron/distributed.py  +26  -14
fmoe/megatron/layers.py       +20  -35
fmoe/megatron/patch.py        +62  -0
fmoe/layers.py
@@ -107,8 +107,8 @@ class FMoE(nn.Module):
             self.slice_size = 1
             self.slice_rank = 0
         else:
-            self.slice_size = slice_group.size()
-            self.slice_rank = slice_group.rank()
+            self.slice_size = self.slice_group.size()
+            self.slice_rank = self.slice_group.rank()
 
         self.top_k = top_k
         if type(expert) is list:
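The change above only switches the reads to `self.slice_group`, which `FMoE.__init__` stores earlier; the slice bookkeeping itself is unchanged and simply mirrors a `torch.distributed` process group. A minimal, single-process sketch of that bookkeeping (the group setup here is illustrative, not how FMoE builds its groups):

    # Illustrative single-process setup; FMoE receives `slice_group` from its caller.
    import torch.distributed as dist

    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    slice_group = dist.new_group(ranks=[0])

    if slice_group is None:
        slice_size, slice_rank = 1, 0
    else:
        # The same .size()/.rank() calls the constructor makes on self.slice_group.
        slice_size = slice_group.size()
        slice_rank = slice_group.rank()

    print(slice_size, slice_rank)  # 1 0 in this single-process run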
fmoe/megatron/__init__.py
@@ -15,5 +15,6 @@ from .balance import reset_gate_hook
 from .balance import get_balance_profile
 from .balance import generate_megatron_gate_hook
 from .balance import add_balance_log
-from .balance import patch_forward_step
-from .balance import patch_model_provider
+from .patch import patch_forward_step
+from .patch import patch_model_provider
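Because the package `__init__` re-exports both helpers, downstream imports should keep working unchanged after the move from `.balance` to `.patch`:

    # Same import path as before; only the implementation moved to fmoe/megatron/patch.py.
    from fmoe.megatron import patch_forward_step, patch_model_provider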
fmoe/megatron/balance.py
@@ -5,7 +5,6 @@ import torch
 from fmoe.balance import reset_balance_profile
 from fmoe.balance import update_balance_profile
 from fmoe.utils import get_torch_default_comm
-from .distributed import get_moe_group
 
 balance_dict = {}
@@ -71,63 +70,3 @@ def add_balance_log(model, writer, iteration):
                 balance_dict_tensor[idx].mean().item(),
                 iteration,
             )
-
-
-def patch_forward_step(forward_step_func):
-    r"""
-    Patch model's forward_step_func to support balance loss
-    """
-    from megatron.mpu import is_pipeline_last_stage
-    from megatron import get_args
-
-    if not get_args().balance_strategy:
-        return forward_step_func
-
-    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
-        args = get_args()
-        output = forward_step_func(data_iterator, model, input_tensor)
-        if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
-            return output
-
-        loss_name = args.balance_strategy + "_loss"
-        while hasattr(model, 'module'):
-            model = model.module
-        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
-                for l in model.language_model.transformer.layers]
-
-        (loss, state_dict), bal_loss = (
-            output,
-            torch.cat(loss_list).mean() * args.balance_loss_weight
-        )
-
-        # avarage across moe group
-        moe_group = get_moe_group()
-        world_size = torch.distributed.get_world_size(group=moe_group)
-        averaged_bal_loss = bal_loss.clone().detach()
-        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
-        averaged_bal_loss /= world_size
-
-        loss += bal_loss
-        state_dict[loss_name] = averaged_bal_loss
-
-        return loss, state_dict
-
-    return forward_step_with_balance_loss
-
-
-def patch_model_provider(model_provider):
-    from megatron import get_args
-
-    def fmoefied_model_provider():
-        from .layers import fmoefy
-        args = get_args()
-        return fmoefy(
-            model_provider(),
-            num_experts=args.num_experts,
-            hidden_hidden_size=4 * args.hidden_size // args.top_k,
-            top_k=args.top_k,
-        )
-
-    return fmoefied_model_provider
fmoe/megatron/distributed.py
 r"""
 distributed support for Megatron
 """
+import torch
+
 from fmoe.distributed import DistributedGroupedDataParallel
 
 
-_moe_group = None
+_groups = None
 
 
-def set_moe_group(moe_group):
-    global _moe_group
-    _moe_group = moe_group
+def _set_groups(**kwargs):
+    global _groups
+    _groups = kwargs
 
 
-def get_moe_group():
-    return _moe_group
+def _init():
+    from megatron import get_args
+    from megatron import mpu
+    args = get_args()
+
+    # Create a comm prependicular to pipeline group as gate group
+    stage_size = args.world_size // args.pipeline_model_parallel_size
+    for i in range(0, args.world_size, stage_size):
+        ranks = range(i, i + stage_size)
+        group = torch.distributed.new_group(ranks)
+        if args.rank in ranks:
+            gate_group = group
+
+    _set_groups(
+        dp_group=mpu.get_data_parallel_group(),
+        moe_group=mpu.get_data_parallel_group(),
+        gate_group=gate_group)
 
 
 class DistributedDataParallel(DistributedGroupedDataParallel):
@@ -24,14 +41,9 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
"""
def
__init__
(
self
,
module
):
from
megatron
import
mpu
super
().
__init__
(
module
,
mp_group
=
mpu
.
get_model_parallel_group
(),
dp_group
=
mpu
.
get_data_parallel_group
(),
moe_group
=
_moe_group
)
if
_groups
is
None
:
_init
()
super
().
__init__
(
module
,
**
_groups
)
def
state_dict
(
self
,
*
args
,
**
kwargs
):
r
"""
...
...
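With this change the wrapper no longer needs a prior `set_moe_group()` call: the first `DistributedDataParallel(...)` construction runs `_init()`, which derives the data-parallel, MoE, and gate groups from `megatron.mpu` and caches them in `_groups`. A rough usage sketch, assuming Megatron-LM has already initialized `torch.distributed` and `mpu`:

    # Sketch only: `module` is whatever Megatron model you built; the group
    # plumbing is resolved lazily inside the wrapper on first use.
    from fmoe.megatron import DistributedDataParallel as FMoEDDP

    def wrap_for_fastmoe(module):
        # First call triggers _init() and caches dp/moe/gate groups in _groups;
        # later calls reuse the cached groups.
        return FMoEDDP(module)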
fmoe/megatron/layers.py
@@ -10,7 +10,6 @@ import torch.nn.functional as F
 from fmoe.transformer import FMoETransformerMLP
 from .balance import reset_gate_hook
 from .balance import generate_megatron_gate_hook
-from .distributed import set_moe_group
 
 
 class _FakeMegatronMLP(nn.Module):
@@ -75,31 +74,31 @@ class MegatronMLP(FMoETransformerMLP):
     communication group `group` to replace the original MLP layer in Megatron.
     """
 
-    def __init__(self, args, mp_group, moe_group, layer_idx):
+    def __init__(self, args, layer_idx):
         assert (
             args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
             == 0
         ), "Batch size x sequence length should be multiple of mp size"
         if not args.distributed_experts:
             world_size = 1
+            moe_group = None
         else:
-            world_size = args.tensor_model_parallel_size * args.data_parallel_size
+            world_size = args.data_parallel_size
+            from megatron.mpu import get_data_parallel_group
+            moe_group = get_data_parallel_group()
         gate = None
         if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
             gate = NaiveGate
         elif args.balance_strategy == "noisy":
             from fmoe.gates import NoisyGate
             gate = NoisyGate
         elif args.balance_strategy == "gshard":
             from fmoe.gates import GShardGate
             gate = GShardGate
         elif args.balance_strategy == "switch":
             from fmoe.gates import SwitchGate
             gate = SwitchGate
         else:
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)
@@ -110,7 +109,6 @@ class MegatronMLP(FMoETransformerMLP):
             d_model=args.hidden_size,
             d_hidden=args.hidden_hidden_size,
             world_size=world_size,
-            mp_group=mp_group,
             moe_group=moe_group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
             gate_hook=generate_megatron_gate_hook(
@@ -139,8 +137,11 @@ class MegatronMLP(FMoETransformerMLP):
         _megatron_init_method(self.experts.h4toh, rng, std)
 
     def forward(self, inp):
+        from megatron import mpu
+        x = super().forward(inp)
+        x = mpu.reduce_from_tensor_model_parallel_region(x)
         return (
-            super().forward(inp),
+            x,
             torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
         )
@@ -167,47 +168,31 @@ def fmoefy(
         tensor_model_parall_comm x data_parallel_comm, which is not created.
     """
     from megatron import get_args
-    from megatron import mpu
     args = get_args()
 
-    # Set distributed_experts to None to use default setting in args
-    if distributed_experts is not None:
-        args.distributed_experts = distributed_experts
-
     if num_experts is not None:
         args.num_experts = num_experts
     assert (
         "num_experts" in args
     ), "num_experts should be specified in arguments or fmoefy function"
 
-    if hidden_hidden_size is not None:
-        args.hidden_hidden_size = hidden_hidden_size
-    elif not hasattr(args, "hidden_hidden_size"):
-        args.hidden_hidden_size = args.hidden_size * 4
-
     if top_k is not None:
         args.top_k = top_k
     elif not hasattr(args, "top_k"):
         args.top_k = 2
 
+    # Set distributed_experts to None to use default setting in args
+    if distributed_experts is not None:
+        args.distributed_experts = distributed_experts
+
+    if hidden_hidden_size is not None:
+        args.hidden_hidden_size = hidden_hidden_size
+    elif not hasattr(args, "hidden_hidden_size"):
+        args.hidden_hidden_size = args.hidden_size * 4 // args.tensor_model_parallel_size
 
-    if hasattr(mpu, 'get_tensor_model_parallel_group'):
-        mp_group = mpu.get_tensor_model_parallel_group()
-    else:
-        # For compatibility to older versions of Megatron-LM
-        mp_group = mpu.get_model_parallel_group()
-
-    if args.pipeline_model_parallel_size == 1:
-        moe_group = None
-    else:
-        # Create a comm prependicular to pipeline group
-        stage_size = args.world_size // args.pipeline_model_parallel_size
-        for i in range(0, args.world_size, stage_size):
-            ranks = range(i, i + stage_size)
-            group = torch.distributed.new_group(ranks)
-            if args.rank in ranks:
-                moe_group = group
-    set_moe_group(moe_group)
-
     for idx, l in enumerate(model.language_model.transformer.layers):
-        l.mlp = MegatronMLP(args, mp_group, moe_group, idx)
+        l.mlp = MegatronMLP(args, idx)
 
     # initialize gate hook
     num_layers = len(model.language_model.transformer.layers)
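After this hunk, `fmoefy` only normalizes the relevant `args` fields and swaps each transformer layer's MLP; `MegatronMLP(args, idx)` resolves its own MoE group from Megatron's data-parallel group instead of receiving `mp_group`/`moe_group`. A hedged sketch of the calling pattern (argument values are illustrative):

    # Sketch only: assumes Megatron-LM has parsed its args and built `model`
    # (a GPT-style module with model.language_model.transformer.layers).
    from fmoe.megatron import fmoefy

    model = fmoefy(
        model,
        num_experts=4,   # illustrative per-worker expert count
        top_k=2,
    )
    # Each layer's mlp is now a MegatronMLP(args, idx); hidden_hidden_size
    # defaults to 4 * hidden_size // tensor_model_parallel_size when not given.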
fmoe/megatron/patch.py (new file, mode 100644)
+r"""
+Patching some of Megatron-LM's functions to create an MoE model
+"""
+
+
+def patch_forward_step(forward_step_func):
+    r"""
+    Patch model's forward_step_func to support balance loss
+    """
+    from megatron.mpu import is_pipeline_last_stage
+    from megatron.mpu import get_tensor_model_parallel_group
+    from megatron import get_args
+
+    if not get_args().balance_strategy:
+        return forward_step_func
+
+    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
+        args = get_args()
+        output = forward_step_func(data_iterator, model, input_tensor)
+        if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
+            return output
+
+        loss_name = args.balance_strategy + "_loss"
+        while hasattr(model, 'module'):
+            model = model.module
+        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
+                for l in model.language_model.transformer.layers]
+
+        (loss, state_dict), bal_loss = (
+            output,
+            torch.cat(loss_list).mean() * args.balance_loss_weight
+        )
+
+        # avarage across moe group
+        moe_group = get_tensor_model_parallel_group()
+        world_size = torch.distributed.get_world_size(group=moe_group)
+        averaged_bal_loss = bal_loss.clone().detach()
+        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
+        averaged_bal_loss /= world_size
+
+        loss += bal_loss
+        state_dict[loss_name] = averaged_bal_loss
+
+        return loss, state_dict
+
+    return forward_step_with_balance_loss
+
+
+def patch_model_provider(model_provider):
+    from megatron import get_args
+
+    def fmoefied_model_provider():
+        from .layers import fmoefy
+        args = get_args()
+        return fmoefy(
+            model_provider(),
+            num_experts=args.num_experts,
+            hidden_hidden_size=4 * args.hidden_size // args.top_k,
+            top_k=args.top_k,
+        )
+
+    return fmoefied_model_provider
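Both helpers wrap callables that a Megatron pretraining script already defines: the model provider handed to `pretrain()` and the per-iteration forward step. A hedged sketch of wiring them in (the entry-point callables are the usual pretrain_gpt.py-style ones and are passed in here so the sketch stays self-contained):

    # Sketch only: `pretrain`, `datasets_provider`, `model_provider`, and
    # `forward_step` are the callables a Megatron-LM training script defines.
    from fmoe.megatron import patch_forward_step, patch_model_provider

    def launch_moe_pretraining(pretrain, datasets_provider, model_provider, forward_step):
        patched_provider = patch_model_provider(model_provider)  # fmoefies the built model
        patched_forward = patch_forward_step(forward_step)       # adds the balance loss term
        pretrain(datasets_provider, patched_provider, patched_forward)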