OpenDAS / FastMoE · Commits

Commit d9dba929, authored Oct 10, 2021 by Rick Ho

tide up fmoe python code
Parent commit: d2392de2
Showing 6 changed files with 136 additions and 143 deletions:

    fmoe/__init__.py       +2    -1
    fmoe/distributed.py    +7   -18
    fmoe/functions.py      +0   -28
    fmoe/layers.py        +29   -73
    fmoe/linear.py        +92    -0
    fmoe/transformer.py    +6   -23

fmoe/__init__.py

@@ -2,6 +2,7 @@ r"""
 The fmoe package contains MoE Layers only.
 """
-from .layers import FMoELinear, FMoE
+from .layers import FMoE
+from .linear import FMoELinear
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
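
The package-level exports stay the same; only the module that defines FMoELinear moves. A small editorial check of the import surface (not part of the commit; the expected values follow from the new import lines above):

# Editorial sanity check, assuming the fmoe package from this commit is installed.
from fmoe import FMoE, FMoELinear, FMoETransformerMLP, DistributedGroupedDataParallel

# FMoELinear is now re-exported from fmoe.linear instead of fmoe.layers.
print(FMoELinear.__module__)   # expected: "fmoe.linear"
print(FMoE.__module__)         # expected: "fmoe.layers"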

fmoe/distributed.py

@@ -25,11 +25,8 @@ class DistributedGroupedDataParallel(nn.Module):
     def __init__(
         self,
         module,
         mp_group=None,
-        dp_group=None,
-        moe_group=None,
-        world_group=None,
         auto_allreduce=False,
         **kwargs
     ):
         assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
@@ -37,20 +34,12 @@ class DistributedGroupedDataParallel(nn.Module):
         self.module = module
         self.comms = dict()
         if mp_group is not None:
             self.comms["mp"] = mp_group
-        if dp_group is not None:
-            self.comms["dp"] = dp_group
-        else:
-            self.comms["dp"] = get_torch_default_comm()
-        if moe_group is not None:
-            self.comms["moe"] = moe_group
-        else:
-            self.comms["moe"] = get_torch_default_comm()
-        if world_group is None:
-            self.comms["world"] = get_torch_default_comm()
-        else:
-            self.comms["world"] = world_group
+        for k in kwargs:
+            if k.endswith('_group'):
+                self.comms[k[:-6]] = kwargs[k]
+        for k in ['dp', 'gate', 'moe', 'world']:
+            if k not in self.comms:
+                self.comms[k] = get_torch_default_comm()

         def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False):
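
In the refactored constructor, any keyword argument ending in `_group` is stored in self.comms under its prefix (`k[:-6]` strips the six-character `_group` suffix), and `dp`, `gate`, `moe`, and `world` fall back to the default communicator when not supplied. A hedged usage sketch follows; the process groups and the wrapped module are illustrative placeholders, and torch.distributed is assumed to be initialized (e.g. via torchrun) with four ranks:

# Editorial sketch, not part of the commit.
import torch
import torch.distributed as dist
from fmoe.distributed import DistributedGroupedDataParallel

# Placeholder wrapped model; in practice this is an FMoE-based network.
module = torch.nn.Linear(16, 16)

# Hypothetical process groups matching a 4-rank launch.
dp_group = dist.new_group(ranks=[0, 1, 2, 3])
gate_group = dist.new_group(ranks=[0, 1, 2, 3])

model = DistributedGroupedDataParallel(
    module,
    dp_group=dp_group,      # picked up by the *_group loop -> self.comms["dp"]
    gate_group=gate_group,  # any other *_group kwarg works the same -> self.comms["gate"]
)
# "moe" and "world" were not passed, so they fall back to get_torch_default_comm().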

fmoe/functions.py

@@ -132,34 +132,6 @@ class MOEScatter(Function):
         grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
         return grad_in, None, None, None, None, None


-class MOELinear(Function):
-    r"""
-    Computes linear operators within one GPU on different experts simutaneously.
-    """
-
-    @staticmethod
-    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
-        global_output_buf = fmoe_cuda.linear_forward(
-            global_input_buf, fwd_expert_count, weight, bias
-        )
-        variables = (global_input_buf, fwd_expert_count, weight, bias)
-        ctx.save_for_backward(*variables)
-        return global_output_buf
-
-    @staticmethod
-    def backward(ctx, grad_out):
-        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
-        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
-            grad_out, input_buf, fwd_expert_count, weight, bias
-        )
-
-        if not torch.is_tensor(bias):
-            grad_bias = None
-
-        return grad_inp_buf, None, grad_weight, grad_bias


 class MOEGather(Function):
     r"""
     Gather output samples from contiguous alone experts back to [batch x
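
The unchanged context above shows the pattern both MOEScatter and the relocated MOELinear follow: torch.autograd.Function.backward returns one value per argument of forward, with None for non-differentiable arguments (hence the five trailing Nones). A minimal, self-contained illustration of that contract (editorial example, unrelated to fmoe_cuda):

import torch
from torch.autograd import Function

class ScaleBy(Function):
    @staticmethod
    def forward(ctx, x, factor, comm_group):   # three forward arguments
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward argument: a gradient for x, None for the rest.
        return grad_out * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
y = ScaleBy.apply(x, 2.0, None)
y.sum().backward()
print(x.grad)   # tensor([2., 2., 2.])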

fmoe/layers.py

 r"""
-Layers that FMoE provides to users
+FMoE core layer
 """
 import torch
 import torch.nn as nn
-import math

 from .functions import prepare_forward, ensure_comm
-from .functions import MOEScatter, MOEGather, MOELinear
+from .functions import MOEScatter, MOEGather
 from .functions import AllGather, Slice
 from .gates import NaiveGate


-class FMoELinear(nn.Module):
-    r"""
-    A linear layer that contains multiple experts.
-    As multiple experts can be placed on the same worker, the computation can be
-    performed in parallel to increase the performance.
-    The FMoELinear module provides such function.
-    """
-
-    def __init__(
-        self,
-        num_expert: int,
-        in_feat: int,
-        out_feat: int,
-        bias: bool = True,
-        rank: int = 0,
-    ):
-        super().__init__()
-        self.num_expert = num_expert
-        self.in_feat = in_feat
-        self.out_feat = out_feat
-        self.rank = rank
-        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
-        else:
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def forward(self, inp, fwd_expert_count):
-        r"""
-        Call MOE function
-        """
-        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
-        return x
-
-    def extra_repr(self) -> str:
-        return "num_expert={}, in_features={}, \
-out_features={}, bias={}, rank={}".format(
-            self.num_expert,
-            self.in_feat,
-            self.out_feat,
-            self.bias is not None,
-            self.rank,
-        )
-
-    def reset_parameters(self):
-        # Approach is the same as in torch.nn.Linear
-        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
-        # bias is left to zero, similar as megatron
-        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))


 def mark_module_parallel_comm(module, comm):
     r"""
@@ -121,11 +67,12 @@ class FMoE(nn.Module):
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contains
     different experts.
-    * `mp_group` can be a torch's communication group, indicating that model
-    parallel is applied across the group, which means that workers in the group
-    hold the same copy of the input feature, and demands the same copy of the
-    output. FMoE saves computation by slicing the input in the mp group and
-    performing all-gather after the MLP computation.
+    * `slice_group` can be a torch's communication group, indicating that
+    specific model parallel is applied across the group, and workers in the
+    group hold the same copy of input feature, and requires the same copy of
+    the output. For each worker, FMoE only computes the output of a certain
+    slice of the input batch, and will all-gather the outputs after
+    computation.
     * `top_k` stands for the number of experts each token is going to.
     * `gate` is a gate class which can found in `fmoe.gates`.
    * `expert` can be specified as a module class, it is used to generate
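
As an editorial aside (not part of the diff), the `slice_group` behaviour described in the new docstring text above can be sketched with plain torch.distributed collectives; FMoE's actual Slice and AllGather autograd functions additionally handle gradients, and `expensive_expert_mlp` below is a hypothetical stand-in for the MoE computation:

import torch
import torch.distributed as dist

def slice_then_all_gather(inp, group, expensive_expert_mlp):
    # Every rank in `group` holds the same `inp`; each rank computes only its slice.
    world = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    per_rank = inp.shape[0] // world                   # assumes the batch divides evenly
    local = inp[rank * per_rank:(rank + 1) * per_rank]

    local_out = expensive_expert_mlp(local)            # work is done only on this slice

    # Every rank needs the full output, so all-gather the slices back together.
    chunks = [torch.empty_like(local_out) for _ in range(world)]
    dist.all_gather(chunks, local_out, group=group)
    return torch.cat(chunks, dim=0)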
@@ -137,7 +84,8 @@ class FMoE(nn.Module):
         num_expert=32,
         d_model=1024,
         world_size=1,
-        mp_group=None,
+        mp_group=None,  # being deprecated
+        slice_group=None,
         moe_group=None,
         top_k=2,
         gate=NaiveGate,
@@ -150,13 +98,18 @@ class FMoE(nn.Module):
         self.num_expert = num_expert
         self.d_model = d_model
         self.world_size = world_size
-        self.mp_group = mp_group
-        if mp_group is None:
-            self.mp_size = 1
-            self.mp_rank = 0
+        self.slice_group = slice_group
+        if mp_group is not None:
+            print('[Warning] mp_group is being deprecated')
+            self.slice_group = mp_group
+        if self.slice_group is None:
+            self.slice_size = 1
+            self.slice_rank = 0
         else:
-            self.mp_size = mp_group.size()
-            self.mp_rank = mp_group.rank()
+            self.slice_size = slice_group.size()
+            self.slice_rank = slice_group.rank()
         self.top_k = top_k
         if type(expert) is list:
             self.experts = nn.ModuleList([e(d_model) for e in expert])
@@ -168,6 +121,7 @@ class FMoE(nn.Module):
             self.experts_fused = False
         else:
             self.experts_fused = True

         self.gate = gate(d_model, num_expert, world_size, top_k)
         self.gate_hook = gate_hook
         self.mask = mask
@@ -203,7 +157,7 @@ class FMoE(nn.Module):
                 mark_module_parallel_comm(e, comm)
         else:
             mark_module_parallel_comm(self.experts, comm)
-        mark_module_parallel_comm(self.gate, "moe")
+        mark_module_parallel_comm(self.gate, "gate")

     def forward(self, inp):
         r"""
@@ -213,8 +167,9 @@ class FMoE(nn.Module):
         """
         if self.world_size > 1:
             ensure_comm(inp, self.moe_group)
-        if self.mp_size > 1:
-            inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
+        if self.slice_size > 1:
+            inp = Slice.apply(inp, self.slice_rank, self.slice_size, self.slice_group)
         gate_top_k_idx, gate_score = self.gate(inp)
@@ -249,6 +204,7 @@ class FMoE(nn.Module):
gate_score
=
gate_score
.
view
(
x
.
shape
[
0
],
1
,
self
.
top_k
)
x
=
torch
.
bmm
(
gate_score
,
x
).
reshape
(
-
1
,
self
.
d_model
)
if
self
.
mp_size
>
1
:
x
=
AllGather
.
apply
(
x
,
self
.
mp_rank
,
self
.
mp_size
,
self
.
mp_group
)
if
self
.
slice_size
>
1
:
x
=
AllGather
.
apply
(
x
,
self
.
slice_rank
,
self
.
slice_size
,
self
.
slice_group
)
return
x
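
An editorial migration sketch (not from the diff) showing the caller-side effect of the rename; `my_group` is a hypothetical torch.distributed process group, and the keyword names are taken from the FMoE signature shown above:

import torch.distributed as dist
from fmoe.layers import FMoE

# Assumes torch.distributed has been initialized (e.g. via torchrun) with 4 ranks.
my_group = dist.new_group(ranks=[0, 1, 2, 3])

# Old call site: still accepted, but prints "[Warning] mp_group is being deprecated"
# and is rewritten internally to slice_group.
moe = FMoE(num_expert=4, d_model=16, world_size=4, mp_group=my_group)

# New call site: equivalent behaviour, no warning.
moe = FMoE(num_expert=4, d_model=16, world_size=4, slice_group=my_group)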

fmoe/linear.py (new file, mode 100644)

r"""
FMoE's parallel linear layer
"""
import torch
import torch.nn as nn
from torch.autograd import Function
import math

import fmoe_cuda


class MOELinear(Function):
    r"""
    Computes linear operators within one GPU on different experts simutaneously.
    """

    @staticmethod
    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
        global_output_buf = fmoe_cuda.linear_forward(
            global_input_buf, fwd_expert_count, weight, bias
        )
        variables = (global_input_buf, fwd_expert_count, weight, bias)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
            grad_out, input_buf, fwd_expert_count, weight, bias
        )

        if not torch.is_tensor(bias):
            grad_bias = None

        return grad_inp_buf, None, grad_weight, grad_bias


class FMoELinear(nn.Module):
    r"""
    A linear layer that contains multiple experts.
    As multiple experts can be placed on the same worker, the computation can be
    performed in parallel to increase the performance.
    The FMoELinear module provides such function.
    """

    def __init__(
        self,
        num_expert: int,
        in_feat: int,
        out_feat: int,
        bias: bool = True,
        rank: int = 0,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rank = rank
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def forward(self, inp, fwd_expert_count):
        r"""
        Call MOE function
        """
        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
        return x

    def extra_repr(self) -> str:
        return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
            self.num_expert,
            self.in_feat,
            self.out_feat,
            self.bias is not None,
            self.rank,
        )

    def reset_parameters(self):
        # Approach is the same as in torch.nn.Linear
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
        # bias is left to zero, similar as megatron
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
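
MOELinear itself only dispatches to the fmoe_cuda extension; the kernel is expected to run one GEMM per expert over contiguous segments of the token buffer, with segment lengths given by `fwd_expert_count`. The following is an editorial, unfused pure-PyTorch sketch of that forward semantics, not the actual CUDA implementation:

import torch

def moe_linear_forward_reference(inp, fwd_expert_count, weight, bias=None):
    # inp:    (sum(fwd_expert_count), in_feat)  -- tokens already grouped by expert
    # weight: (num_expert, out_feat, in_feat)
    # bias:   (num_expert, out_feat) or None
    outputs = []
    base = 0
    for e, count in enumerate(fwd_expert_count.tolist()):
        if count == 0:
            continue
        segment = inp[base:base + count]      # tokens routed to expert e
        out = segment @ weight[e].t()         # (count, out_feat)
        if bias is not None:
            out = out + bias[e]
        outputs.append(out)
        base += count
    if not outputs:
        return inp.new_zeros(0, weight.shape[1])
    return torch.cat(outputs, dim=0)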

fmoe/transformer.py

@@ -3,8 +3,8 @@ Adaption to act as the MLP layer using an MoE MLP layer in transformer.
 """
 import torch
 import torch.nn as nn
-from .gates import NaiveGate
-from .layers import FMoE, FMoELinear
+from .layers import FMoE
+from .linear import FMoELinear


 class _Expert(nn.Module):
@@ -42,31 +42,14 @@ class FMoETransformerMLP(FMoE):
         num_expert=32,
         d_model=1024,
         d_hidden=4096,
-        world_size=1,
-        mp_group=None,
-        moe_group=None,
         activation=torch.nn.GELU(),
-        gate=NaiveGate,
-        top_k=2,
         expert_dp_comm="none",
-        gate_hook=None,
-        mask=None,
-        mask_dict=None,
         expert_rank=0,
         **kwargs
     ):
-        super().__init__(
-            num_expert=num_expert,
-            d_model=d_model,
-            gate=gate,
-            top_k=top_k,
-            world_size=world_size,
-            mp_group=mp_group,
-            moe_group=moe_group,
-            gate_hook=gate_hook,
-            mask=mask,
-            mask_dict=mask_dict
-        )
+        super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
         self.experts = _Expert(
-            num_expert, d_model, d_hidden, activation, rank=self.mp_rank
+            num_expert, d_model, d_hidden, activation, rank=expert_rank
         )
         self.mark_parallel_comm(expert_dp_comm)
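
With this change, FMoETransformerMLP keeps only the MLP-specific arguments and forwards everything else (for example world_size, moe_group, top_k, gate) to the FMoE base class through **kwargs. A hedged usage sketch, assuming the fmoe_cuda extension is built and a GPU is available; single-process settings are chosen so no communication group is needed:

# Editorial sketch, not from the diff: options that used to be explicit
# parameters are now forwarded to FMoE.__init__ via **kwargs.
import torch
from fmoe.transformer import FMoETransformerMLP

mlp = FMoETransformerMLP(
    num_expert=4,
    d_model=64,
    d_hidden=256,
    activation=torch.nn.GELU(),
    # Forwarded to FMoE.__init__ through **kwargs:
    world_size=1,
    top_k=2,
).cuda()

x = torch.randn(8, 64).cuda()
y = mlp(x)
print(y.shape)   # expected: torch.Size([8, 64]), same shape as the input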