OpenDAS / FastMoE · Commits

Commit b3380ec2, authored Feb 07, 2021 by Rick Ho

    support arbitrary module as expert

Parent: 8328c794

Showing 4 changed files with 162 additions and 96 deletions:

    fmoe/__init__.py      +2   -1
    fmoe/layers.py        +67  -75
    fmoe/megatron.py      +22  -20
    fmoe/transformer.py   +71  -0
fmoe/__init__.py

@@ -2,5 +2,6 @@ r"""
 The fmoe package contains MoE Layers only.
 """
-from .layers import FMoELinear, FMoETransformerMLP
+from .layers import FMoELinear, FMoE
+from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
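With this commit, `FMoE` becomes the general layer exported from `fmoe.layers`, while `FMoETransformerMLP` moves to the new `fmoe.transformer` module; the package-level re-exports keep both reachable from `fmoe`. A minimal sketch of the import surface after this change (assuming the package is installed as `fmoe`):

    # All three layer classes plus the DDP wrapper are re-exported by fmoe/__init__.py.
    from fmoe import FMoELinear, FMoE, FMoETransformerMLP
    from fmoe import DistributedGroupedDataParallel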
fmoe/layers.py

@@ -41,15 +41,23 @@ class FMoELinear(nn.Module):
         return MOELinear.apply(inp, self.weight, fwd_expert_count)


-def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
+def mark_module_parallel_comm(module, comm):
+    r'''
+    Mark all parameters in `module` as doing data parallel in `comm`, where
+    `comm` may be one of `'world', 'dp', 'none'`.
+    '''
+    for p in module.parameters():
+        setattr(p, 'dp_comm', comm)
+
+
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     r'''
     A private function that performs the following steps to complete the MoE
     computation.
     * Count the number of tokens from each worker to each expert.
     * Send the features to their target position so that input features to each
     expert are contiguous in memory.
-    * Perform the MLP of the experts by applying MoELinear and the activation in
-    turns.
+    * Perform the forward computation of the experts using `expert_fn`
     * Gather the output features of experts back, and reorder them as sentences.
     Intermediate results like expert counts are hidden from users by this
     function.
@@ -62,19 +70,18 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
             inp, pos, local_expert_count, global_expert_count, fwd_batch_size,
             world_size)
-    for i, l in enumerate(linears):
-        if i:
-            x = activation(x)
-        x = l(x, fwd_expert_count)
+    x = expert_fn(x, fwd_expert_count)
     x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
             inp.shape[0], world_size)
     return x


-class FMoETransformerMLP(nn.Module):
+class FMoE(nn.Module):
     r'''
-    A complete MoE MLP module in a Transformer block.
+    A general moe implementation that supports an arbitrary module as the expert
+    Either `expert` or `expert_fn` is required.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contains
     different experts.
@@ -83,25 +90,19 @@ class FMoETransformerMLP(nn.Module):
     hold the same copy of the input feature, and demands the same copy of the
     output. FMoE saves computation by slicing the input in the mp group and
     performing all-gather after the MLP computation.
-    * `activation` is the activation function to be used in MLP in each expert.
     * `top_k` stands for the number of experts each token is going to.
     * `gate` is a gate class which can found in `fmoe.gates`.
+    * `expert` can be specified as a module class, it is used to generate
+    `num_expert` expert modules.
+    * `expert_fn` is specified as a callable object or a function, it will be
+    called during forward, giving the input tensor (contiguous) and the array of
+    the number of input feature to each expert as input.
     '''
-    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-            world_size=1, mp_group=None,
-            activation=torch.nn.functional.gelu,
-            gate=NaiveGate, top_k=2, pre_lnorm=False):
+    def __init__(self, num_expert=32, d_model=1024, world_size=1,
+            mp_group=None, top_k=2, gate=NaiveGate,
+            expert=None, expert_fn=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
-        self.d_hidden = d_hidden
         self.world_size = world_size
         self.mp_group = mp_group
         if mp_group is None:
@@ -110,37 +111,46 @@ class FMoETransformerMLP(nn.Module):
         else:
             self.mp_size = mp_group.size()
             self.mp_rank = mp_group.rank()
-        self.activation = activation
-        self.pre_lnorm = pre_lnorm
         self.top_k = top_k
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
-        if self.world_size > self.mp_size:
-            for p in self.htoh4.parameters():
-                setattr(p, 'dp_comm', 'none')
-            for p in self.h4toh.parameters():
-                setattr(p, 'dp_comm', 'none')
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        for p in self.gate.parameters():
-            setattr(p, 'dp_comm', 'world')
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.bias = torch.nn.parameter.Parameter(
-                torch.zeros(d_model, dtype=torch.float32))
-
-    def forward(self, inp: torch.Tensor):
+        if expert_fn is None:
+            assert expert is not None, 'Either expert or expert_fn should be set'
+            self.experts = [expert(d_model) for _ in range(num_expert)]
+
+            def expert_fn(self, inp, fwd_expert_count):
+                outputs = []
+                base_idx = 0
+                for i in range(self.num_expert):
+                    batch_size = fwd_expert_count[i].item()
+                    inp_slice = inp[base_idx:base_idx + batch_size]
+                    outputs.append(self.experts[i](inp_slice))
+                    base_idx += batch_size
+                return torch.cat(outputs, dim=0)
+        self.expert_fn = expert_fn
+
+    def mark_parallel_comm(self):
         r'''
-        The FMoETransformerMLP module automatically performs reshape and layer
-        normalization. The score of the selected gate given by the expert is
-        multiplied to the experts' output tensors as a weight.
+        Automatically mark the data parallel comms of the parameters within the
+        module. This can be typically called at the end of the __init__ function
+        in child classes.
         '''
+        if self.experts is not None:
+            if self.world_size > self.mp_size:
+                comm = 'none'
+            else:
+                comm = 'dp'
+            if isinstance(self.experts, list):
+                for e in self.experts:
+                    mark_module_parallel_comm(e, comm)
+            else:
+                mark_module_parallel_comm(self.experts, comm)
+        mark_module_parallel_comm(self.gate, 'world')
+
+    def forward(self, inp):
+        r'''
+        The FMoE module first computes gate output, and then conduct MoE forward
+        according to the gate. The score of the selected gate given by the
+        expert is multiplied to the experts' output tensors as a weight.
+        '''
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
         if self.mp_size > 1:
             B: int = inp.shape[0]
             local_batch_size = B // self.mp_size
@@ -148,35 +158,17 @@ class FMoETransformerMLP(nn.Module):
             batch_end = min(batch_start + local_batch_size, B)
             inp = inp[batch_start:batch_end]
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
         gate_top_k_idx, gate_score = self.gate(inp)
         # to: (BxLxtop_k) x d_model
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
-        x = _fmoe_full_forward(inp, gate_top_k_idx,
-                [self.htoh4, self.h4toh], self.activation,
-                self.num_expert, self.world_size,
-        )
+        x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
+                self.num_expert, self.world_size)
         # to: (BxL) x top_k x d_model
-        core_out = x.view(-1, self.top_k, self.d_model)
-        # to: (BxL) x 1 x d_model
-        core_out = torch.bmm(gate_score, core_out)
-        output = core_out.reshape(residual.shape) + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
+        x = x.view(-1, self.top_k, self.d_model)
+        # to: (BxL) x d_model
+        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
         if self.mp_size > 1:
-            output = AllGather.apply(output,
+            x = AllGather.apply(x,
                     self.mp_rank, self.mp_size, self.mp_group)
-        return output.reshape(original_shape), self.bias
+        return x
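The new `expert` argument lets `FMoE` instantiate `num_expert` copies of an arbitrary `nn.Module` by calling `expert(d_model)`, while `expert_fn` remains available for fully custom expert computation. A minimal sketch of how a custom expert class might be plugged in; the `MyExpert` module is hypothetical, and actually running the layer requires the fmoe CUDA extensions:

    import torch
    import torch.nn as nn
    from fmoe.layers import FMoE

    class MyExpert(nn.Module):
        '''Hypothetical expert: constructed as expert(d_model), maps (tokens, d_model) -> (tokens, d_model).'''
        def __init__(self, d_model):
            super().__init__()
            self.fc1 = nn.Linear(d_model, 4 * d_model)
            self.fc2 = nn.Linear(4 * d_model, d_model)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    moe = FMoE(num_expert=4, d_model=256, top_k=2, expert=MyExpert)
    out = moe(torch.randn(8, 256))   # each token is routed to its top-2 of the 4 local experts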
fmoe/megatron.py

@@ -3,33 +3,35 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `exapmles/megatron` for usage instructions.
 '''
-from .layers import FMoETransformerMLP
+import torch
+
+from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
 from .utils import get_torch_default_comm


-def _create_moe_mlp(args, group):
+class MegatronMLP(FMoETransformerMLP):
     r'''
     Make the FMoETransformerMLP layer that distributes experts across
     communication group `group` to replace the original MLP layer in Megatron.
     '''
-    assert (args.seq_length * args.micro_batch_size
-            % args.tensor_model_parallel_size == 0
-            ), "Batch size x sequence length should be multiple of mp size"
-    if not args.distributed_experts:
-        world_size = 1
-    else:
-        world_size = args.world_size
-    fmoe = FMoETransformerMLP(args.num_experts,
-            d_model=args.hidden_size, d_hidden=args.hidden_size * 4,
-            world_size=world_size, mp_group=group)
-    for p in fmoe.gate.parameters():
-        setattr(p, 'shared', True)
-    return fmoe
+    def __init__(self, args, group):
+        assert (args.seq_length * args.micro_batch_size
+                % args.tensor_model_parallel_size == 0
+                ), "Batch size x sequence length should be multiple of mp size"
+        if not args.distributed_experts:
+            world_size = 1
+        else:
+            world_size = args.world_size
+        super().__init__(args.num_experts,
+                d_model=args.hidden_size, d_hidden=args.hidden_size * 4,
+                world_size=world_size, mp_group=group)
+        self.bias = torch.nn.parameter.Parameter(
+                torch.zeros(args.hidden_size, dtype=torch.float32))
+
+    def forward(self, inp):
+        return super().forward(inp), self.bias


 def fmoefy(model, num_experts=None, distributed_experts=True):
@@ -60,7 +62,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
     args.distributed_experts = distributed_experts
     for l in model.language_model.transformer.layers:
-        l.mlp = _create_moe_mlp(args, get_torch_default_comm())
+        l.mlp = MegatronMLP(args, get_torch_default_comm())
     return model
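Because the base FMoETransformerMLP no longer returns a bias term, MegatronMLP now adds its own zero bias and returns the `(output, bias)` pair that Megatron's MLP interface expects. A rough sketch of how `fmoefy` is meant to be applied inside a Megatron-LM pretraining script, per the adaptor's docstring; the `build_megatron_model` call is a placeholder for Megatron's own model construction and is not part of this diff:

    from fmoe.megatron import fmoefy

    model = build_megatron_model()            # placeholder: Megatron-LM builds its model here
    model = fmoefy(model, num_experts=4,      # replace every transformer layer's MLP with MegatronMLP
                   distributed_experts=False)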
fmoe/transformer.py  0 → 100644  (new file)

r'''
Adaption to act as the MLP layer using an MoE MLP layer in transformer.
'''
import torch
import torch.nn as nn

from .gates import NaiveGate
from .layers import FMoE, FMoELinear


class _Expert(nn.Module):
    r'''
    An expert using 2 FMoELinear modules to speed up the computation of experts
    within one worker.
    '''
    def __init__(self, num_expert, d_model, d_hidden, activation):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r'''
        First expand input to 4h (the hidden size is variable, but is called h4
        for convenience). Then perform activation. Finally shirink back to h.
        '''
        x = self.htoh4(inp, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)
        return x


class FMoETransformerMLP(FMoE):
    r'''
    A complete MoE MLP module in a Transformer block.
    * `activation` is the activation function to be used in MLP in each expert.
    * `d_hidden` is the dimension of the MLP layer.
    '''
    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
            world_size=1, mp_group=None,
            activation=torch.nn.functional.gelu,
            gate=NaiveGate, top_k=2, pre_lnorm=False):
        def expert_fn(inp, gate):
            return self.experts(inp, gate)
        super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                world_size=world_size, mp_group=mp_group, expert_fn=expert_fn)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)
        self.mark_parallel_comm()

    def forward(self, inp: torch.Tensor):
        r'''
        This module wraps up the FMoE module with reshape, residual and layer
        normalization.
        '''
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        output = super().forward(inp) + inp
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output.reshape(original_shape)
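FMoETransformerMLP is now a thin subclass of FMoE that fuses each worker's experts into two FMoELinear calls and wraps the MoE forward with the reshape, residual, and layer norm that previously lived in layers.py. A minimal usage sketch (again, actually running it requires the fmoe CUDA extensions):

    import torch
    from fmoe.transformer import FMoETransformerMLP

    mlp = FMoETransformerMLP(num_expert=8, d_model=512, d_hidden=2048, top_k=2)
    x = torch.randn(4, 128, 512)   # (batch, seq_len, d_model); flattened to tokens internally
    y = mlp(x)                     # residual and LayerNorm are applied inside forward
    assert y.shape == x.shape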