OpenDAS / FastMoE / Commits

Commit 50a9aa94 (unverified), authored Jul 08, 2021 by Rick Ho, committed by GitHub on Jul 08, 2021

Merge pull request #59 from laekov/cope-with-pipeline

Use moe_group instead of world for MoE

Parents: 55f8ca7d, 59913cca

Showing 7 changed files with 62 additions and 21 deletions (+62, -21)
fmoe/distributed.py           +5   -0
fmoe/functions.py             +4   -8
fmoe/layers.py               +10   -5
fmoe/megatron/balance.py      +5   -4
fmoe/megatron/distributed.py +13   -0
fmoe/megatron/layers.py      +23   -4
fmoe/transformer.py           +2   -0
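Taken together, the changes below route every MoE-related collective (the expert all-to-all, the gate's gradient all-reduce, and the balance-loss averaging) over an explicit moe_group communicator instead of the implicit world group, so that FastMoE can coexist with pipeline parallelism. A minimal end-to-end sketch of how a caller might use the new argument, assuming torchrun has started one process per GPU with NCCL available and the fmoe_cuda extension is built; the two-stage split and all sizes here are illustrative, not part of this commit:

import torch
import torch.distributed as dist
from fmoe.transformer import FMoETransformerMLP
from fmoe.distributed import DistributedGroupedDataParallel

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Illustrative split: 2 pipeline stages, experts shared only within a stage.
pipeline_size = 2
stage_size = world_size // pipeline_size
moe_group = None
for i in range(0, world_size, stage_size):
    ranks = list(range(i, i + stage_size))
    group = dist.new_group(ranks)     # every rank must enter every new_group call
    if rank in ranks:
        moe_group = group

# moe_group (new in this commit) replaces the default world communicator for
# the MoE all-to-all and for reducing parameters tagged with the "moe" comm.
layer = FMoETransformerMLP(num_expert=4, d_model=1024, d_hidden=4096,
                           world_size=stage_size, moe_group=moe_group).cuda()
model = DistributedGroupedDataParallel(layer, moe_group=moe_group)

The sketch only shows where the new argument plugs in; the rest of the training loop is unchanged.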
fmoe/distributed.py

@@ -27,6 +27,7 @@ class DistributedGroupedDataParallel(nn.Module):
         module,
         mp_group=None,
         dp_group=None,
+        moe_group=None,
         world_group=None,
         auto_allreduce=False,
     ):
@@ -42,6 +43,10 @@ class DistributedGroupedDataParallel(nn.Module):
             self.comms["dp"] = dp_group
         else:
             self.comms["dp"] = get_torch_default_comm()
+        if moe_group is not None:
+            self.comms["moe"] = moe_group
+        else:
+            self.comms["moe"] = get_torch_default_comm()
         if world_group is None:
             self.comms["world"] = get_torch_default_comm()
         else:
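The wrapper now keeps a separate "moe" entry in its communicator table, defaulting to the global communicator when the argument is omitted so existing callers keep the old behaviour. A small inspection sketch, assuming ddp is a DistributedGroupedDataParallel instance built as in the overview above (hypothetical variable name); "mp" is listed on the assumption that the mp_group argument is stored the same way:

# The gate's parameters are tagged with dp_comm="moe" in fmoe/layers.py below,
# so their gradients are now reduced over comms["moe"] rather than the world.
for name in ("dp", "mp", "moe", "world"):
    print(name, ddp.comms.get(name))   # groups not passed explicitly fall back
                                       # to get_torch_default_comm()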
fmoe/functions.py

@@ -10,13 +10,13 @@ import fmoe_cuda
 from .utils import get_torch_default_comm


-def _ensure_nccl(t, comm=None):
+def ensure_comm(t, comm):
     if comm is None:
         comm = get_torch_default_comm()
     fmoe_cuda.ensure_nccl(comm, t)


-def count_by_gate(gate, num_expert, world_size, require_pos=True):
+def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
             num_expert * world_size, device=gate.device, dtype=torch.int32
@@ -25,7 +25,6 @@ def count_by_gate(gate, num_expert, world_size, require_pos=True):
         local_expert_count = local_expert_count.long()

         if world_size > 1:
-            _ensure_nccl(gate)
             global_expert_count = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size
             )
@@ -41,7 +40,7 @@ def count_by_gate(gate, num_expert, world_size, require_pos=True):
     return pos, local_expert_count, global_expert_count


-def prepare_forward(gate, num_expert, world_size, comm=None):
+def prepare_forward(gate, num_expert, world_size, comm):
     r"""
     Prepare necessary information from gate output for MoE computation.
@@ -52,11 +51,8 @@ def prepare_forward(gate, num_expert, world_size, comm=None):
     world_size: number of workers that hold different experts.
     comm: the communicator of all workers in the expert-parallel group.
     """
-    if world_size > 1:
-        _ensure_nccl(gate, comm=comm)
-
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
-            num_expert, world_size)
+            num_expert, world_size, comm)
     with torch.no_grad():
         fwd_expert_count = global_expert_count.view(world_size,
                 num_expert).sum(dim=0)
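ensure_comm(t, comm) now takes the communicator explicitly and is called once by the layer's forward pass (see fmoe/layers.py below), rather than being re-resolved inside count_by_gate and prepare_forward. For intuition, a pure-PyTorch sketch of what the local expert count assembled here represents; it does not touch fmoe_cuda and the numbers are made up:

import torch

# Top-1 gate output for 5 tokens on this rank, with 2 experts on each of
# 2 workers, i.e. 4 global expert slots numbered 0..3.
num_expert, world_size = 2, 2
gate = torch.tensor([0, 3, 3, 1, 2])
local_expert_count = torch.zeros(num_expert * world_size, dtype=torch.long)
local_expert_count.scatter_add_(0, gate, torch.ones_like(gate))
print(local_expert_count)   # tensor([1, 1, 1, 2])

# With world_size > 1, fmoe_cuda.expert_exchange then swaps these counts
# across the workers of the communicator that prepare_forward now receives.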
fmoe/layers.py

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import math

-from .functions import prepare_forward
+from .functions import prepare_forward, ensure_comm
 from .functions import MOEScatter, MOEGather, MOELinear
 from .functions import AllGather, Slice
 from .gates import NaiveGate
@@ -74,7 +74,8 @@ def mark_module_parallel_comm(module, comm):
             setattr(p, "dp_comm", comm)


-def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
+        comm=None):
     r"""
     A private function that performs the following steps to complete the MoE
     computation.
@@ -92,7 +93,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         global_expert_count,
         fwd_expert_count,
         fwd_batch_size,
-    ) = prepare_forward(gate, num_expert, world_size)
+    ) = prepare_forward(gate, num_expert, world_size, comm)
     topk = 1
     if len(gate.shape) == 2:
         topk = gate.shape[1]
@@ -138,6 +139,7 @@ class FMoE(nn.Module):
         d_model=1024,
         world_size=1,
         mp_group=None,
+        moe_group=None,
         top_k=2,
         gate=NaiveGate,
         expert=None,
@@ -171,6 +173,7 @@ class FMoE(nn.Module):
         self.gate_hook = gate_hook
         self.mask = mask
         self.mask_dict = mask_dict
+        self.moe_group = moe_group

     def expert_fn(self, inp, fwd_expert_count):
         r"""
@@ -201,7 +204,7 @@ class FMoE(nn.Module):
                 mark_module_parallel_comm(e, comm)
         else:
             mark_module_parallel_comm(self.experts, comm)
-        mark_module_parallel_comm(self.gate, "world")
+        mark_module_parallel_comm(self.gate, "moe")

     def forward(self, inp):
         r"""
@@ -209,6 +212,8 @@ class FMoE(nn.Module):
         according to the gate. The score of the selected gate given by the
         expert is multiplied to the experts' output tensors as a weight.
         """
+        if self.world_size > 1:
+            ensure_comm(inp, self.moe_group)
         if self.mp_size > 1:
             inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
@@ -224,7 +229,7 @@ class FMoE(nn.Module):
         fwd = _fmoe_general_global_forward(
             inp,
             gate_top_k_idx,
-            self.expert_fn, self.num_expert, self.world_size
+            self.expert_fn, self.num_expert, self.world_size, self.moe_group
         )
         # recover deleted tensors
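In forward(), ensure_comm(inp, self.moe_group) is only reached when world_size > 1, so single-process use is unchanged. A minimal single-GPU sketch using the public FMoETransformerMLP subclass (sizes are illustrative; assumes a CUDA device and the built fmoe_cuda extension):

import torch
from fmoe.transformer import FMoETransformerMLP

# world_size == 1: the new ensure_comm branch is skipped and no moe_group
# is needed, matching the behaviour before this commit.
layer = FMoETransformerMLP(num_expert=4, d_model=256, d_hidden=1024,
                           world_size=1, top_k=2).cuda()

x = torch.randn(128, 256, device="cuda")   # (tokens, d_model)
y = layer(x)
print(y.shape)                              # torch.Size([128, 256])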
fmoe/megatron/balance.py

@@ -5,6 +5,7 @@ import torch
 from fmoe.balance import reset_balance_profile
 from fmoe.balance import update_balance_profile
 from fmoe.utils import get_torch_default_comm
+from .distributed import get_moe_group

 balance_dict = {}
@@ -101,11 +102,11 @@ def patch_forward_step(forward_step_func):
                 torch.cat(loss_list).mean() * args.balance_loss_weight
             )

-            # avarage across world group
-            world_group = get_torch_default_comm()
-            world_size = torch.distributed.get_world_size(group=world_group)
+            # avarage across moe group
+            moe_group = get_moe_group()
+            world_size = torch.distributed.get_world_size(group=moe_group)
             averaged_bal_loss = bal_loss.clone().detach()
-            torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
+            torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
             averaged_bal_loss /= world_size
             loss += bal_loss
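The balance loss is now averaged over the MoE group returned by get_moe_group() rather than the default (world) communicator, so each pipeline stage averages only over the ranks it actually shares experts with. A generic sketch of the average-over-a-group pattern used here, assuming torch.distributed is already initialized; the helper name is hypothetical:

import torch
import torch.distributed as dist

def average_over_group(value: torch.Tensor, group=None) -> torch.Tensor:
    # Detach a copy, sum it across the ranks of `group`, then divide by the
    # group's size; with group=None this is the old world-wide behaviour.
    averaged = value.clone().detach()
    dist.all_reduce(averaged, group=group)
    averaged /= dist.get_world_size(group=group)
    return averaged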
fmoe/megatron/distributed.py

@@ -4,6 +4,18 @@ distributed support for Megatron
 from fmoe.distributed import DistributedGroupedDataParallel


+_moe_group = None
+
+
+def set_moe_group(moe_group):
+    global _moe_group
+    _moe_group = moe_group
+
+
+def get_moe_group():
+    return _moe_group
+
+
 class DistributedDataParallel(DistributedGroupedDataParallel):
     r"""
     A wrapper that is used to replace the DDP module provided by Megatron, which
@@ -18,6 +30,7 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
             module,
             mp_group=mpu.get_model_parallel_group(),
             dp_group=mpu.get_data_parallel_group(),
+            moe_group=_moe_group
         )

     def state_dict(self, *args, **kwargs):
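set_moe_group()/get_moe_group() form a small module-level registry, so the group chosen in fmoefy() can be picked up later by the Megatron DDP wrapper and by the balance-loss averaging without threading it through every call. A sketch of the intended calling order, assuming Megatron-LM is installed, torch.distributed is initialized, and the rank list is illustrative:

import torch.distributed as dist
from fmoe.megatron.distributed import set_moe_group, get_moe_group

# fmoefy() performs the equivalent of this once the stage-local group exists.
stage_group = dist.new_group([0, 1, 2, 3])   # every rank must make this call
set_moe_group(stage_group)

# Later, DistributedDataParallel passes moe_group=_moe_group to its base class
# and fmoe/megatron/balance.py reads the same handle back:
assert get_moe_group() is stage_group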
fmoe/megatron/layers.py

@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from fmoe.transformer import FMoETransformerMLP
 from .balance import reset_gate_hook
 from .balance import generate_megatron_gate_hook
+from .distributed import set_moe_group


 class _FakeMegatronMLP(nn.Module):
@@ -74,7 +75,7 @@ class MegatronMLP(FMoETransformerMLP):
     communication group `group` to replace the original MLP layer in Megatron.
     """

-    def __init__(self, args, group, layer_idx):
+    def __init__(self, args, mp_group, moe_group, layer_idx):
         assert (
             args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
             == 0
@@ -82,7 +83,7 @@ class MegatronMLP(FMoETransformerMLP):
         if not args.distributed_experts:
             world_size = 1
         else:
-            world_size = args.world_size
+            world_size = args.tensor_model_parallel_size * args.data_parallel_size
         gate = None
         if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
@@ -102,13 +103,15 @@ class MegatronMLP(FMoETransformerMLP):
             gate = SwitchGate
         else:
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)

         super().__init__(
             args.num_experts,
             top_k=args.top_k,
             d_model=args.hidden_size,
             d_hidden=args.hidden_hidden_size,
             world_size=world_size,
-            mp_group=group,
+            mp_group=mp_group,
+            moe_group=moe_group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
             gate_hook=generate_megatron_gate_hook(
                 layer_idx, args.num_experts * world_size
@@ -187,8 +190,24 @@ def fmoefy(
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts

+    if hasattr(mpu, 'get_tensor_model_parallel_group'):
+        mp_group = mpu.get_tensor_model_parallel_group()
+    else:
+        # For compatibility to older versions of Megatron-LM
+        mp_group = mpu.get_model_parallel_group()
+    if args.pipeline_model_parallel_size == 1:
+        moe_group = None
+    else:
+        # Create a comm prependicular to pipeline group
+        stage_size = args.world_size // args.pipeline_model_parallel_size
+        for i in range(0, args.world_size, stage_size):
+            ranks = range(i, i + stage_size)
+            group = torch.distributed.new_group(ranks)
+            if args.rank in ranks:
+                moe_group = group
+    set_moe_group(moe_group)
     for idx, l in enumerate(model.language_model.transformer.layers):
-        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)
+        l.mlp = MegatronMLP(args, mp_group, moe_group, idx)

     # initialize gate hook
     num_layers = len(model.language_model.transformer.layers)
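The grouping added to fmoefy() slices the world into contiguous blocks of stage_size = world_size // pipeline_model_parallel_size ranks, which under Megatron's usual rank ordering is one block per pipeline stage, perpendicular to the pipeline groups. A worked example of the rank arithmetic only (plain Python, no process groups created), for an 8-rank job with 2 pipeline stages:

# 8 ranks, 2 pipeline stages -> each stage owns a block of 4 consecutive ranks.
world_size = 8
pipeline_model_parallel_size = 2
stage_size = world_size // pipeline_model_parallel_size      # 4

groups = [list(range(i, i + stage_size))
          for i in range(0, world_size, stage_size)]
print(groups)   # [[0, 1, 2, 3], [4, 5, 6, 7]]

# The expert all-to-all therefore never crosses a pipeline-stage boundary.
# torch.distributed.new_group() must be entered by every rank for every group,
# which is why the loop in fmoefy() iterates over all blocks and keeps only
# the group that contains args.rank.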
fmoe/transformer.py

@@ -44,6 +44,7 @@ class FMoETransformerMLP(FMoE):
         d_hidden=4096,
         world_size=1,
         mp_group=None,
+        moe_group=None,
         activation=torch.nn.GELU(),
         gate=NaiveGate,
         top_k=2,
@@ -59,6 +60,7 @@ class FMoETransformerMLP(FMoE):
             top_k=top_k,
             world_size=world_size,
             mp_group=mp_group,
+            moe_group=moe_group,
             gate_hook=gate_hook,
             mask=mask,
             mask_dict=mask_dict