OpenDAS / FastMoE / Commits

Commit f2040d9f
Authored Feb 05, 2021 by Rick Ho
Commit message: pass pylint
Parent: bf2fd0c0

Showing 6 changed files with 160 additions and 32 deletions:
.pylintrc            +6   -3
fmoe/distributed.py  +22  -5
fmoe/functions.py    +5   -1
fmoe/layers.py       +73  -13
fmoe/megatron.py     +39  -5
fmoe/utils.py        +15  -5
.pylintrc

@@ -138,7 +138,10 @@ disable=print-statement,
         xreadlines-attribute,
         deprecated-sys-function,
         exception-escape,
-        comprehension-escape
+        comprehension-escape,
+        arguments-differ,
+        import-outside-toplevel,
+        signature-differs,

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

@@ -398,7 +401,7 @@ indent-after-paren=4
 indent-string='    '

 # Maximum number of characters on a single line.
-max-line-length=100
+max-line-length=81

 # Maximum number of lines in a module.
 max-module-lines=1000

@@ -553,7 +556,7 @@ preferred-modules=
 max-args=12

 # Maximum number of attributes for a class (see R0902).
-max-attributes=7
+max-attributes=32

 # Maximum number of boolean expressions in an if statement (see R0916).
 max-bool-expr=5
fmoe/distributed.py

+r'''
+Supportive modules to conduct distributed training
+'''
 import torch
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

@@ -5,11 +8,24 @@ from .utils import get_torch_default_comm
 class DistributedGroupedDataParallel(nn.Module):
+    r'''
+    A customized DDP module to support different all-reduce regions in the
+    model. The all-reduce region is defined as an attribute `dp_comm` in the
+    weight object.
+    The grads of the weights are identified to be reduced in different groups
+    according to the weights' `dp_comm` attribute.
+    If it is set to `dp`, it will only be reduced across the data-parallel
+    groups, which means that in the model-parallel group, they are not
+    synchronized.
+    If it is set to `world`, the gradients are synchronized across all workers,
+    regardless of their model- or data-parallel group. This is extremely useful
+    for shared layers like the gate.
+    '''
     def __init__(self, module, mp_group=None, dp_group=None, world_group=None,
             auto_allreduce=False):
         assert not auto_allreduce, 'Automatic all-reduce is not implemented yet'
-        super(DistributedGroupedDataParallel, self).__init__()
+        super().__init__()
         self.module = module
         self.comms = dict()

@@ -39,10 +55,9 @@ class DistributedGroupedDataParallel(nn.Module):
                     groups[group_key] = [p]
                 else:
                     groups[group_key].append(p)
-            for dp_comm, dtype in groups:
+            for (dp_comm, dtype), group in groups.items():
                 if dp_comm not in self.comms:
                     continue
-                group = groups[dp_comm, dtype]
                 comm = self.comms[dp_comm]
                 grads = [p.grad.data for p in group]
                 coalesced = _flatten_dense_tensors(grads)

@@ -61,5 +76,7 @@ class DistributedGroupedDataParallel(nn.Module):
         self.allreduce_params = allreduce_params

     def forward(self, *args, **kwargs):
+        r'''
+        Directly call the module's forward function.
+        '''
         return self.module(*args, **kwargs)
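The docstring added above describes the `dp_comm` tagging convention. Below is a minimal usage sketch, not taken from the repository: the toy model is hypothetical, and it assumes `torch.distributed` has already been initialized (e.g. via torchrun with an NCCL backend).

# Minimal sketch of the `dp_comm` convention described above (hypothetical
# toy model; assumes torch.distributed is already initialized).
import torch.nn as nn
from fmoe.distributed import DistributedGroupedDataParallel

model = nn.ModuleDict({
    "gate": nn.Linear(16, 4),     # shared layer, kept in sync everywhere
    "expert": nn.Linear(16, 16),  # kept in sync only within the dp group
})
for p in model["gate"].parameters():
    p.dp_comm = "world"  # gradients reduced across every worker
for p in model["expert"].parameters():
    p.dp_comm = "dp"     # gradients reduced only inside the data-parallel group

ddp = DistributedGroupedDataParallel(model)
# ... forward pass and loss.backward() ...
# ddp.allreduce_params()  # reduce the gradients group by group, as tagged above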
fmoe/functions.py

@@ -40,7 +40,8 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         )
     else:
         global_expert_count = local_expert_count
     fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (
         pos,

@@ -175,6 +176,9 @@ class MOEGather(Function):
 class AllGather(Function):
+    r'''
+    A wrapper for the All-Gather function to support auto-differentiation.
+    '''
     @staticmethod
     def forward(ctx, inp, rank, world_size, group):
         tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
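`AllGather` wraps a collective in `torch.autograd.Function` so that gradients flow back to the local shard. The sketch below illustrates the general technique with plain `torch.distributed` calls; it is an independent illustration under the assumption of equal per-rank batch sizes, not FastMoE's exact implementation, and the class name is made up.

# Illustrative sketch of an autograd-aware all-gather (hypothetical class;
# assumes every rank contributes a tensor of identical shape).
import torch
import torch.distributed as dist
from torch.autograd import Function

class AllGatherSketch(Function):
    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        # Gather one copy of `inp` from every rank and concatenate them.
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
        dist.all_gather(tensor_list, inp, group=group)
        ctx.rank, ctx.batch_size = rank, inp.shape[0]
        return torch.cat(tensor_list, dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        # Only the slice that came from this rank needs a gradient; the
        # non-tensor inputs (rank, world_size, group) get None.
        start = ctx.rank * ctx.batch_size
        grad_inp = grad_out[start:start + ctx.batch_size]
        return grad_inp, None, None, None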
fmoe/layers.py

-from .functions import *
+r'''
+Layers that FMoE provides to users
+'''
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from .functions import moe_prepare_forward
+from .functions import MOEScatter, MOEGather, MOELinear
+from .functions import AllGather

 class FMoELinear(nn.Module):
+    r'''
+    A linear layer that contains multiple experts.
+    As multiple experts can be placed on the same worker, the computation can be
+    performed in parallel to increase the performance.
+    The FMoELinear module provides such a function.
+    '''
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
-        super(FMoELinear, self).__init__()
+        super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat

@@ -13,21 +26,40 @@ class FMoELinear(nn.Module):
         self.reset_parameters()

     def reset_parameters(self):
+        r'''
+        Initialize the weight as linear layers
+        '''
         for i in range(self.num_expert):
             linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
             self.weight.data[i] = linear.weight.data

     def forward(self, inp, fwd_expert_count):
+        r'''
+        Call MOE function
+        '''
         return MOELinear.apply(inp, self.weight, fwd_expert_count)

 class FMoENaiveGate(nn.Module):
+    r'''
+    A naive gate implementation that defines the standard behavior of the gate
+    which determines which experts the tokens are going to.
+    Both the indices and the score, or confidence, are output to the parent
+    module.
+    The load-balance strategies are also designed to be implemented within the
+    `Gate` module.
+    '''
     def __init__(self, d_model, num_expert, world_size, top_k=2):
-        super(FMoENaiveGate, self).__init__()
+        super().__init__()
         self.gate = nn.Linear(d_model, num_expert * world_size)
         self.top_k = top_k

     def forward(self, inp):
+        r'''
+        The naive implementation simply calculates the top-k of a linear layer's
+        output.
+        '''
         gate = self.gate(inp)
         gate_top_k_val, gate_top_k_idx = torch.topk(
             gate, k=self.top_k, dim=-1, largest=True, sorted=False

@@ -42,15 +74,25 @@ class FMoENaiveGate(nn.Module):
 def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
+    r'''
+    A private function that performs the following steps to complete the MoE
+    computation.
+    * Count the number of tokens from each worker to each expert.
+    * Send the features to their target position so that input features to each
+    expert are contiguous in memory.
+    * Perform the MLP of the experts by applying MoELinear and the activation in
+    turns.
+    * Gather the output features of experts back, and reorder them as sentences.
+    Intermediate results like expert counts are hidden from users by this
+    function.
+    '''
     (
         pos,
-        local_expert_count, global_expert_count, fwd_expert_count,
-        fwd_batch_size
+        local_expert_count,
+        global_expert_count,
+        fwd_expert_count,
+        fwd_batch_size,
     ) = moe_prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
         inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
     )
     for i, l in enumerate(linears):
         if i:

@@ -63,6 +105,19 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
 class FMoETransformerMLP(nn.Module):
+    r'''
+    A complete MoE MLP module in a Transformer block.
+    * `num_expert` stands for the number of experts on **each** worker.
+    * `world_size` stands for the total number of workers that contain
+    different experts.
+    * `mp_group` can be a torch communication group, indicating that model
+    parallelism is applied across the group, which means that workers in the
+    group hold the same copy of the input feature and demand the same copy of
+    the output. FMoE saves computation by slicing the input in the mp group and
+    performing all-gather after the MLP computation.
+    * `activation` is the activation function to be used in the MLP of each
+    expert.
+    * `top_k` stands for the number of experts each token is going to.
+    '''
     def __init__(
         self,
         num_expert=32,

@@ -72,9 +127,9 @@ class FMoETransformerMLP(nn.Module):
         mp_group=None,
         activation=torch.nn.functional.gelu,
         top_k=2,
         pre_lnorm=False,
     ):
-        super(FMoETransformerMLP, self).__init__()
+        super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
         self.d_hidden = d_hidden

@@ -103,6 +158,11 @@ class FMoETransformerMLP(nn.Module):
         )

     def forward(self, inp: torch.Tensor):
+        r'''
+        The FMoETransformerMLP module automatically performs reshape and layer
+        normalization. The score of the selected gate given by the expert is
+        multiplied with the experts' output tensors as a weight.
+        '''
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
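The `FMoENaiveGate` docstring added above says the gate simply takes the top-k of a linear layer's output and returns both the indices and a confidence score. A standalone sketch of that behaviour in plain PyTorch (single process, no FastMoE imports, values chosen arbitrarily) might look like this:

# Standalone sketch of a naive top-k gate (hypothetical, single-process;
# num_expert * world_size collapses to just num_expert here).
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, num_expert, top_k = 16, 4, 2
gate_layer = nn.Linear(d_model, num_expert)

tokens = torch.randn(8, d_model)        # 8 tokens
logits = gate_layer(tokens)             # shape (8, num_expert)
val, idx = torch.topk(logits, k=top_k, dim=-1, largest=True, sorted=False)
# Softmax over the selected logits is one common way to form the confidence
# score; this step is an assumption, the commit does not show it.
score = F.softmax(val, dim=-1)
# `idx` tells which experts each token goes to; `score` weights their outputs.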
fmoe/megatron.py

+'''
+The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
+lines of modification.
+See `examples/megatron` for usage instructions.
+'''
 from .layers import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel

-def create_moe_mlp(args, group):
+def _create_moe_mlp(args, group):
+    r'''
+    Make the FMoETransformerMLP layer that distributes experts across
+    communication group `group` to replace the original MLP layer in Megatron.
+    '''
     assert (
         args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size == 0
     ), "Batch size x sequence length should be multiple of mp size"
     if not args.distributed_experts:
         world_size = 1

@@ -23,6 +32,17 @@ def create_moe_mlp(args, group):
 def fmoefy(model, num_experts=None, distributed_experts=True):
+    r'''
+    Replace MLP layers in a transformer-based model in Megatron by MoE.
+    * `model` should be a standard Megatron model that has
+    `model.language_model.transformer.layers` as transformer layers, which is an
+    array of transformer blocks that contain an `mlp` member.
+    * `distributed_experts` is set to True if different experts are located on
+    different workers. Otherwise, the experts on the workers are identical, and
+    they are trained in data-parallel mode. This can be useful when testing on
+    small models that do not require high training throughput or large parameter
+    capacity.
+    '''
     from megatron import get_args
     from megatron import mpu
     args = get_args()

@@ -37,24 +57,38 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
     args.distributed_experts = distributed_experts
     for l in model.language_model.transformer.layers:
-        l.mlp = create_moe_mlp(args, mpu.get_model_parallel_group())
+        l.mlp = _create_moe_mlp(args, mpu.get_model_parallel_group())
     return model

 class DistributedDataParallel(DistributedGroupedDataParallel):
+    r'''
+    A wrapper that is used to replace the DDP module provided by Megatron,
+    adapted to enable the sophisticated parallel and reduction strategies in
+    Fast MoE.
+    '''
     def __init__(self, module):
         from megatron import mpu
-        super(DistributedDataParallel, self).__init__(
+        super().__init__(
             module,
             mp_group=mpu.get_model_parallel_group(),
             dp_group=mpu.get_data_parallel_group()
         )

     def state_dict(self, *args, **kwargs):
+        r'''
+        Keep consistency with Megatron
+        '''
         return self.module.state_dict(*args, **kwargs)

     def state_dict_for_save_checkpoint(self, *args, **kwargs):
+        r'''
+        Keep consistency with Megatron
+        '''
         return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

     def load_state_dict(self, *args, **kwargs):
+        r'''
+        Keep consistency with Megatron
+        '''
         return self.module.load_state_dict(*args, **kwargs)
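The module docstring added above promises integration "with at most two lines of modification". A hedged sketch of how those lines might look inside a Megatron-LM v2.0 training script (assuming `model` has already been built and the Megatron arguments parsed; the expert count is arbitrary) is:

# Sketch of the intended Megatron-LM integration (assumes a Megatron v2.0
# training script where `model` already exists).
from fmoe.megatron import fmoefy, DistributedDataParallel

model = fmoefy(model, num_experts=8)    # swap every MLP for an FMoETransformerMLP
model = DistributedDataParallel(model)  # use FastMoE's grouped DDP wrapper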
fmoe/utils.py

+r'''
+Utils to play with PyTorch.
+'''
 import torch.distributed as dist

+# pylint: disable=broad-except
+# pylint: disable=protected-access
 def get_torch_default_comm():
+    r'''
+    The NCCL communicator is needed so that Fast MoE can perform customized
+    communication operators in the C code. However, it is not a publicly
+    available variable. Therefore, a hacking class of the `ProcessGroupNCCL`
+    in Fast MoE's C code takes the `_default_pg` and tries to dig the
+    communicator out from the object. As PyTorch's private interface varies from
+    time to time, different hacking techniques are tried one-by-one to be
+    compatible with various versions of PyTorch.
+    '''
     try:
         comm = dist.distributed_c10d._get_default_group()
         return comm
-    except Exception as e:
-        print('Error {}'.format(e))
+    except Exception as _:
         pass
     try:
         comm = dist.distributed_c10d._default_pg

@@ -15,6 +28,3 @@ def get_torch_default_comm():
     except Exception as _:
         pass
     raise RuntimeError('Unsupported PyTorch version')
-    return None
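The docstring added above explains that several private PyTorch interfaces are probed one by one until one yields the default process group. A generic sketch of that fallback pattern (hypothetical helper name; the real fmoe.utils.get_torch_default_comm differs in detail) is:

# Generic sketch of the version-probing pattern described above
# (hypothetical helper, not the repository's code).
import torch.distributed as dist

def probe_default_group():
    candidates = (
        lambda: dist.distributed_c10d._get_default_group(),  # newer PyTorch
        lambda: dist.distributed_c10d._default_pg,            # older PyTorch
    )
    for get in candidates:
        try:
            group = get()
            if group is not None:
                return group
        except Exception:  # pylint: disable=broad-except
            continue
    raise RuntimeError('Unsupported PyTorch version')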