OpenDAS / ColossalAI

Commit c9023d40 (unverified)
Authored by HELSON on Mar 22, 2022; committed by GitHub on Mar 22, 2022

[MOE] support PR-MOE (#488)

Parent: a9ecb4b2
Showing 3 changed files with 133 additions and 18 deletions (+133 -18)
colossalai/nn/layer/moe/__init__.py   +2   -2
colossalai/nn/layer/moe/layers.py     +100 -3
model_zoo/moe/models.py               +31  -13
colossalai/nn/layer/moe/__init__.py
 from .experts import Experts, FFNExperts, TPExperts
-from .layers import MoeLayer, Top1Router, Top2Router
+from .layers import MoeLayer, Top1Router, Top2Router, MoeModule
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = ['Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer',
-           'NormalNoiseGenerator', 'UniformNoiseGenerator', 'build_ffn_experts']
+           'NormalNoiseGenerator', 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule']
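
With these re-exports, the new MoeModule wrapper (added in layers.py below) becomes available from the package root:

from colossalai.nn.layer.moe import MoeModule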
colossalai/nn/layer/moe/layers.py
...
@@ -7,9 +7,9 @@ import torch.distributed as dist
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
-from .experts import MoeExperts
-from .utils import ForceFP32Parameter
-from typing import Callable, Optional
+from .experts import MoeExperts, Experts
+from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
...
@@ -315,3 +315,100 @@ class MoeLayer(nn.Module):
         ans = ans.reshape(inputs.shape)
         return ans
 
+
+class MoeModule(nn.Module):
+    """A class for users to create MoE modules in their models.
+
+    Args:
+        dim_model (int): Hidden dimension of the training model
+        num_experts (int): The number of experts
+        top_k (int, optional): The number of experts to which each token is dispatched
+        capacity_factor_train (float, optional): Capacity factor in routing during training
+        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
+        min_capacity (int, optional): The minimum capacity of each expert
+        noisy_policy (str, optional): The policy of the noise function. Currently 'Jitter' and 'Gaussian' are supported.
+            'Jitter' can be found in the Switch Transformer paper (https://arxiv.org/abs/2101.03961).
+            'Gaussian' can be found in the ViT-MoE paper (https://arxiv.org/abs/2106.05974).
+        drop_tks (bool, optional): Whether to drop tokens during evaluation
+        use_residual (bool, optional): Makes this MoE layer a Residual MoE.
+            More information can be found in the Microsoft paper (https://arxiv.org/abs/2201.05596).
+        residual_instance (nn.Module, optional): The instance of the residual module in Residual MoE
+        expert_instance (MoeExperts, optional): The instance of the experts module in MoeLayer
+        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
+        expert_args (optional): The args of the experts when no instance is given
+    """
+
+    def __init__(self,
+                 dim_model: int,
+                 num_experts: int,
+                 top_k: int = 1,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_policy: Optional[str] = None,
+                 drop_tks: bool = True,
+                 use_residual: bool = False,
+                 residual_instance: Optional[nn.Module] = None,
+                 expert_instance: Optional[MoeExperts] = None,
+                 expert_cls: Optional[Type[nn.Module]] = None,
+                 **expert_args):
+        super().__init__()
+
+        noisy_func = None
+        if noisy_policy is not None:
+            if noisy_policy == 'Jitter':
+                noisy_func = UniformNoiseGenerator()
+            elif noisy_policy == 'Gaussian':
+                noisy_func = NormalNoiseGenerator(num_experts)
+            else:
+                raise NotImplementedError("Unsupported input noisy policy")
+
+        if top_k == 1:
+            moe_router_cls = Top1Router
+        elif top_k == 2:
+            moe_router_cls = Top2Router
+        else:
+            raise NotImplementedError("top_k > 2 is not supported yet")
+
+        self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
+                                         capacity_factor_eval=capacity_factor_eval,
+                                         min_capacity=min_capacity,
+                                         noisy_func=noisy_func,
+                                         drop_tks=drop_tks)
+        self.use_residual = use_residual
+        if use_residual:
+            if residual_instance is not None:
+                self.residual_module = residual_instance
+            else:
+                assert expert_cls is not None, \
+                    "Expert class can't be None when residual instance is not given"
+                self.residual_module = expert_cls(**expert_args)
+            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+
+        if expert_instance is not None:
+            self.experts = expert_instance
+        else:
+            assert expert_cls is not None, \
+                "Expert class can't be None when experts instance is not given"
+            self.experts = Experts(expert_cls, num_experts, **expert_args)
+
+        self.moe_layer = MoeLayer(dim_model=dim_model,
+                                  num_experts=num_experts,
+                                  router=self.moe_router,
+                                  experts=self.experts)
+
+    def forward(self, inputs: torch.Tensor):
+        moe_output = self.moe_layer(inputs)
+
+        if self.use_residual:
+            residual_output = self.residual_module(inputs)
+            combine_coef = self.residual_combine(inputs)
+            combine_coef = F.softmax(combine_coef, dim=-1)
+            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
+        else:
+            output = moe_output
+
+        return output
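
For reference, PR-MoE (the Pyramid-Residual MoE described in the Microsoft paper linked in the docstring, https://arxiv.org/abs/2201.05596) replaces top-2 routing with top-1 routing plus a dense residual expert, and MoeModule.forward mixes the two outputs through a learned two-way softmax gate. A minimal usage sketch, not part of this commit: it assumes a ColossalAI MoE environment has already been initialized (e.g. via colossalai.launch), and SimpleFFN is a hypothetical expert class with illustrative sizes.

import torch
import torch.nn as nn
from colossalai.nn.layer.moe import MoeModule

class SimpleFFN(nn.Module):
    # Hypothetical expert: any nn.Module mapping (..., d_model) -> (..., d_model) works.
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_ff),
                                 nn.GELU(),
                                 nn.Linear(d_ff, d_model))

    def forward(self, x):
        return self.net(x)

# Residual MoE: top-1 routing plus a dense residual expert; MoeModule combines
# the two outputs with softmax(nn.Linear(dim_model, 2)) as in forward() above.
moe_ffn = MoeModule(dim_model=256,
                    num_experts=8,
                    top_k=1,
                    noisy_policy='Jitter',
                    use_residual=True,
                    expert_cls=SimpleFFN,
                    d_model=256, d_ff=1024)    # forwarded to SimpleFFN via **expert_args

x = torch.randn(4, 64, 256)    # (batch, seq_len, dim_model)
y = moe_ffn(x)                 # same shape as x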
model_zoo/moe/models.py
...
@@ -4,11 +4,12 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
+from typing import List
 
 
 class VanillaSelfAttention(nn.Module):
...
@@ -146,7 +147,8 @@ class Widenet(nn.Module):
 class ViTMoE(nn.Module):
 
     def __init__(self,
-                 num_experts: int,
+                 num_experts: int or List[int],
+                 use_residual: bool = False,
                  capacity_factor_train: float = 1.25,
                  capacity_factor_eval: float = 2.0,
                  drop_tks: bool = True,
...
@@ -164,29 +166,45 @@ class ViTMoE(nn.Module):
                  drop_path: float = 0.):
         super().__init__()
 
+        assert depth % 2 == 0, "The number of layers should be even right now"
+
+        if isinstance(num_experts, list):
+            assert len(num_experts) == depth // 2, \
+                "The length of num_experts should equal the number of MOE layers"
+            num_experts_list = num_experts
+        else:
+            num_experts_list = [num_experts] * (depth // 2)
+
         embedding = VanillaPatchEmbedding(img_size=img_size,
                                           patch_size=patch_size,
                                           in_chans=in_chans,
                                           embed_size=d_model)
         embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
 
-        noisy_func = NormalNoiseGenerator(num_experts)
-        router = Top2Router(capacity_factor_train=capacity_factor_train,
-                            capacity_factor_eval=capacity_factor_eval,
-                            noisy_func=noisy_func,
-                            drop_tks=drop_tks)
-        assert depth % 2 == 0
         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
         blocks = []
         for i in range(depth):
             sa = VanillaSelfAttention(**moe_sa_args(
                 d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
-            ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
-                MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
-                         experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
+
+            if i % 2 == 0:
+                ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+            else:
+                num_experts = num_experts_list[i // 2]
+                experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
+                ffn = MoeModule(dim_model=d_model,
+                                num_experts=num_experts,
+                                top_k=1 if use_residual else 2,
+                                capacity_factor_train=capacity_factor_train,
+                                capacity_factor_eval=capacity_factor_eval,
+                                noisy_policy='Jitter' if use_residual else 'Gaussian',
+                                drop_tks=drop_tks,
+                                use_residual=use_residual,
+                                expert_instance=experts,
+                                expert_cls=VanillaFFN,
+                                **moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+
             layer = TransformerLayer(att=sa, ffn=ffn, norm1=nn.LayerNorm(d_model, eps=1e-6),
...
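
With num_experts accepting a list, ViTMoE can realize the "pyramid" half of PR-MoE: MoE layers occupy every other transformer block, so a depth-d model has d // 2 of them, and deeper layers can get more experts. A construction sketch, not part of this commit: depth is assumed to be a ViTMoE constructor argument (it is used inside __init__ above), and the remaining arguments (img_size, d_model, d_ff, ...) are assumed to keep the defaults defined earlier in model_zoo/moe/models.py.

from model_zoo.moe.models import ViTMoE

# Hypothetical PR-MoE configuration: depth 12 -> 6 MoE layers, with expert
# counts growing in deeper layers (the pyramid) and residual experts enabled.
pyramid_experts = [4, 4, 8, 8, 16, 16]

model = ViTMoE(num_experts=pyramid_experts,    # one entry per MoE layer (depth // 2)
               use_residual=True,              # top-1 routing + dense residual expert
               capacity_factor_train=1.25,
               capacity_factor_eval=2.0,
               depth=12)                       # assumed constructor argument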