Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a088022e
Unverified
Commit
a088022e
authored
Sep 23, 2022
by
HELSON
Committed by
GitHub
Sep 23, 2022
Browse files
[moe] fix moe bugs (#1633)
parent
702dbc52
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
287 additions
and
249 deletions
+287
-249
colossalai/nn/layer/moe/__init__.py
colossalai/nn/layer/moe/__init__.py
+3
-2
colossalai/nn/layer/moe/layers.py
colossalai/nn/layer/moe/layers.py
+16
-229
colossalai/nn/layer/moe/routers.py
colossalai/nn/layer/moe/routers.py
+226
-0
tests/test_moe/test_grad_handler.py
tests/test_moe/test_grad_handler.py
+4
-3
tests/test_moe/test_kernel.py
tests/test_moe/test_kernel.py
+2
-2
tests/test_moe/test_moe_zero_init.py
tests/test_moe/test_moe_zero_init.py
+29
-10
tests/test_moe/test_moe_zero_model.py
tests/test_moe/test_moe_zero_model.py
+4
-2
tests/test_moe/test_moe_zero_optim.py
tests/test_moe/test_moe_zero_optim.py
+3
-1
No files found.
colossalai/nn/layer/moe/__init__.py
View file @
a088022e
from
.experts
import
Experts
,
FFNExperts
,
TPExperts
from
.layers
import
MoeLayer
,
Top1Router
,
Top2Router
,
MoeModule
from
.layers
import
MoeLayer
,
MoeModule
from
.routers
import
MoeRouter
,
Top1Router
,
Top2Router
from
.utils
import
NormalNoiseGenerator
,
UniformNoiseGenerator
,
build_ffn_experts
__all__
=
[
'Experts'
,
'FFNExperts'
,
'TPExperts'
,
'Top1Router'
,
'Top2Router'
,
'MoeLayer'
,
'NormalNoiseGenerator'
,
'UniformNoiseGenerator'
,
'build_ffn_experts'
,
'MoeModule'
'UniformNoiseGenerator'
,
'build_ffn_experts'
,
'MoeModule'
,
'MoeRouter'
]
colossalai/nn/layer/moe/layers.py
View file @
a088022e
import
functools
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.distributed
as
dist
from
colossalai.context.moe_context
import
MOE_CONTEXT
from
colossalai.utils
import
get_current_device
from
._operation
import
COL_MOE_KERNEL_FLAG
,
AllToAll
,
AllGather
,
ReduceScatter
,
MoeDispatch
,
MoeCombine
,
moe_cumsum
from
.experts
import
MoeExperts
,
Experts
from
.utils
import
ForceFP32Parameter
,
UniformNoiseGenerator
,
NormalNoiseGenerator
,
autocast_softmax
from
colossalai.nn.layer.moe._operation
import
COL_MOE_KERNEL_FLAG
,
AllToAll
,
AllGather
,
\
ReduceScatter
,
MoeDispatch
,
MoeCombine
from
colossalai.nn.layer.moe.experts
import
MoeExperts
,
Experts
from
colossalai.nn.layer.moe.utils
import
UniformNoiseGenerator
,
NormalNoiseGenerator
from
colossalai.nn.layer.moe.routers
import
MoeRouter
,
Top1Router
,
Top2Router
from
colossalai.zero.init_ctx
import
no_shard_zero_context
,
no_shard_zero_decrator
from
typing
import
Callable
,
Optional
,
Type
from
torch.distributed
import
ProcessGroup
class Top1Router(nn.Module):
    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
    for routing usage. More detailed function can be found in the paper about Switch Transformer
    of Google.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing of training.
        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
        min_capacity (int, optional): The minimum number of the capacity of each expert.
        select_policy (str, optional): The policy about tokens selection, ``"first"`` or ``"random"``.
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 select_policy: str = "first",
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True):
        super().__init__()
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.select_policy = select_policy
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks

        assert select_policy in {"first", "random"}
        if select_policy == "random":
            # Bound sampler: calling self.uniform(shape) draws U(0, 1) values of that shape.
            self.uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=get_current_device()),
                high=torch.tensor(1.0, device=get_current_device())).rsample

    def get_capacity(self, logits_shape):
        """Return the per-expert capacity for logits of shape ``logits_shape``.

        The capacity is ``floor(factor * logits_shape[-2] / logits_shape[-1])`` rounded up to
        an even number and then raised to at least ``min_capacity``. The factor used depends
        on ``self.training``.
        """
        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
        capacity += capacity % 2    # round odd values up to the next even number
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return capacity

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
        """Route every token to its highest-scoring expert.

        Args:
            inputs (torch.Tensor): Gate scores; softmax and argmax are taken over the last dim
                and tokens are enumerated along dim 0.
            use_kernel (bool, optional): Selects which of the two return layouts is produced.
            ep_group (ProcessGroup, optional): Group passed to ``dist.all_reduce`` when the
                capacity is synchronised (evaluation with ``drop_tks=False``).

        Returns:
            ``(logits, mask, dest_idx, num_experts * capacity)`` when ``use_kernel`` is True,
            otherwise ``(combine_weights, sec_mask)`` where ``sec_mask`` is
            ``combine_weights.bool()``.
        """
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        logits = autocast_softmax(inputs, dim=-1)
        num_experts = logits.size(-1)
        capacity = self.get_capacity(logits.shape)

        top1_idx = torch.argmax(inputs, dim=-1)
        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

        if self.training:
            # auxiliary loss, registered on the global MoE context
            me = torch.mean(logits, dim=0)
            ce = torch.mean(mask.float(), dim=0)
            l_aux = num_experts * torch.sum(me * ce)
            MOE_CONTEXT.add_loss(l_aux)
        elif not self.drop_tks:
            # No token may be dropped: use the largest per-expert load across the group.
            max_num = torch.max(torch.sum(mask, dim=0))
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = max_num.item()

        if self.select_policy == "random":
            rand_mask = mask * self.uniform(mask.shape)
            # torch.topk raises when k exceeds the size of dim 0. The capacity can be larger
            # than the token count (e.g. because of min_capacity), so clamp k.
            topk_k = min(capacity, mask.size(0))
            _, dispatch_idx = torch.topk(rand_mask, k=topk_k, dim=0)
            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
            ranks = moe_cumsum(mask)
        elif self.select_policy == "first":
            ranks = moe_cumsum(mask)
            mask = mask * torch.lt(ranks, capacity)
        else:
            raise NotImplementedError("Not support such select policy yet.")

        # Collapse the expert dim: one slot index per token (0 where the token was dropped).
        ranks = torch.sum(mask * ranks, dim=-1)

        if use_kernel:
            mask = torch.sum(mask, dim=-1)
            mask = torch.stack([mask], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
            return logits, mask, dest_idx, num_experts * capacity
        else:
            ranks = F.one_hot(ranks, num_classes=capacity)
            weight = mask * logits.type_as(inputs)
            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
            sec_mask = combine_weights.bool()
            return combine_weights, sec_mask
class Top2Router(nn.Module):
    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
    for routing usage. More detailed function can be found in the paper about ViT-MoE.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing of training.
        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
        min_capacity (int, optional): The minimum number of the capacity of each expert
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether drops tokens in evaluation.
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_func: Callable = None,
                 drop_tks: bool = True):
        super().__init__()
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks

    def get_capacity(self, logits_shape):
        """Per-expert capacity: floored, made even, then raised to at least ``min_capacity``."""
        if self.training:
            factor = self.capacity_factor_train
        else:
            factor = self.capacity_factor_eval
        capacity = math.floor(factor * logits_shape[-2] / logits_shape[-1])
        if capacity % 2:
            capacity += 1
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return capacity

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
        """Route every token to its two highest-probability experts.

        Returns ``(logits, mask, dest_idx, num_experts * capacity)`` when ``use_kernel`` is
        True, otherwise ``(cb_weight, cb_weight.bool())``.
        """
        if self.training and self.noisy_func is not None:
            inputs = self.noisy_func(inputs)

        probs = autocast_softmax(inputs, dim=-1)
        expert_count = probs.size(-1)
        capacity = self.get_capacity(probs.shape)

        # first choice, then the best of the remaining experts
        first_choice = torch.argmax(probs, dim=-1)
        first_mask = F.one_hot(first_choice, num_classes=expert_count).to(torch.int32)
        remaining = probs.masked_fill(first_mask.bool(), float("-inf"))
        second_choice = torch.argmax(remaining, dim=-1)
        second_mask = F.one_hot(second_choice, num_classes=expert_count).to(torch.int32)

        chosen = first_mask + second_mask

        if self.training:
            mean_prob = probs.mean(dim=0)
            mean_load = chosen.float().mean(dim=0)
            # div 2 to normalize it to 1
            aux_loss = expert_count * (mean_prob * mean_load).sum() / 2.0
            MOE_CONTEXT.add_loss(aux_loss)
        elif not self.drop_tks:
            busiest = chosen.sum(dim=0).max()
            dist.all_reduce(busiest, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = busiest.item()

        first_rank = moe_cumsum(first_mask)
        second_rank = moe_cumsum(second_mask)
        # second choices are placed after all first choices of the same expert
        second_rank += torch.sum(first_mask, dim=-2, keepdim=True)

        first_mask *= torch.lt(first_rank, capacity)
        second_mask *= torch.lt(second_rank, capacity)

        first_rank = torch.sum(first_mask * first_rank, dim=-1)
        second_rank = torch.sum(second_mask * second_rank, dim=-1)

        if use_kernel:
            kept = torch.stack(
                [torch.sum(first_mask, dim=-1), torch.sum(second_mask, dim=-1)],
                dim=0,
            ).to(torch.int32)
            dest_idx = torch.stack(
                [first_choice * capacity + first_rank, second_choice * capacity + second_rank],
                dim=0,
            ).to(torch.int32)
            return probs, kept, dest_idx, expert_count * capacity

        first_weight = first_mask * probs.type_as(inputs)
        second_weight = second_mask * probs.type_as(inputs)
        first_slot = F.one_hot(first_rank, num_classes=capacity)
        second_slot = F.one_hot(second_rank, num_classes=capacity)

        first_cb = first_weight.unsqueeze(2) * first_slot.unsqueeze(1)
        second_cb = second_weight.unsqueeze(2) * second_slot.unsqueeze(1)
        combined = first_cb + second_cb
        return combined, combined.bool()
class FP32LinearGate(nn.Module):
    """Bias-free linear gate for the MoE layer; its weight is meant to stay in fp32.

    Args:
        d_model (int): Hidden dimension of training model
        num_experts (int): The number experts
        scale (float, optional): The weight is initialised with a truncated normal whose
            standard deviation is ``sqrt(scale / d_model)``.

    Attributes:
        weight (ForceFP32Parameter): The weight of linear gate, created with shape
            ``(num_experts, d_model)``.
    """

    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
        super().__init__()
        raw_weight = torch.empty(num_experts, d_model, device=get_current_device())
        self.weight = ForceFP32Parameter(raw_weight)
        init_std = math.sqrt(scale / d_model)
        nn.init.trunc_normal_(self.weight, std=init_std)

    def forward(self, x: torch.Tensor):
        """Apply the gate to ``x`` via ``F.linear`` with no bias."""
        gate_weight = self.weight
        return F.linear(x, gate_weight)
from
typing
import
Optional
,
Type
,
Tuple
@
no_shard_zero_decrator
(
is_replicated
=
True
)
...
...
@@ -238,17 +24,17 @@ class MoeLayer(nn.Module):
Args:
dim_model (int): Dimension of model.
num_experts (int): The number of experts.
router (
:class:`torch.nn.Module`
): Instance of router used in routing.
experts (
:class:`torch.nn.Module`
): Instance of experts generated by Expert.
router (
MoeRouter
): Instance of router used in routing.
experts (
MoeExperts
): Instance of experts generated by Expert.
"""
def
__init__
(
self
,
dim_model
:
int
,
num_experts
:
int
,
router
:
nn
.
Module
,
experts
:
MoeExperts
):
def
__init__
(
self
,
dim_model
:
int
,
num_experts
:
int
,
router
:
MoeRouter
,
experts
:
MoeExperts
):
super
().
__init__
()
self
.
d_model
=
dim_model
self
.
num_experts
=
num_experts
self
.
gate_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
dim_model
))
self
.
router
=
router
self
.
experts
=
experts
self
.
router
:
MoeRouter
=
router
self
.
experts
:
MoeExperts
=
experts
self
.
use_kernel
=
True
if
COL_MOE_KERNEL_FLAG
and
MOE_CONTEXT
.
use_kernel_optim
else
False
self
.
ep_group
=
experts
.
dist_info
.
ep_group
self
.
ep_size
=
experts
.
dist_info
.
ep_size
...
...
@@ -271,7 +57,7 @@ class MoeLayer(nn.Module):
expert_out
=
ReduceScatter
.
apply
(
expert_out
,
self
.
ep_group
)
return
expert_out
def
forward
(
self
,
inputs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
inputs
:
torch
.
Tensor
)
->
Tuple
:
# reshape the input tokens
tokens
=
inputs
.
reshape
(
-
1
,
self
.
d_model
)
...
...
@@ -309,7 +95,8 @@ class MoeLayer(nn.Module):
ans
=
torch
.
matmul
(
combine_weights
,
expert_output
)
ans
=
ans
.
reshape
(
inputs
.
shape
)
return
ans
l_aux
=
self
.
router
.
pop_routing_loss
()
return
ans
,
l_aux
class
MoeModule
(
nn
.
Module
):
...
...
@@ -403,7 +190,7 @@ class MoeModule(nn.Module):
experts
=
self
.
experts
)
def
forward
(
self
,
inputs
:
torch
.
Tensor
):
moe_output
=
self
.
moe_layer
(
inputs
)
moe_output
,
l_aux
=
self
.
moe_layer
(
inputs
)
if
self
.
use_residual
:
residual_output
=
self
.
residual_module
(
inputs
)
...
...
@@ -413,4 +200,4 @@ class MoeModule(nn.Module):
else
:
output
=
moe_output
return
output
return
output
,
l_aux
colossalai/nn/layer/moe/routers.py
0 → 100644
View file @
a088022e
import
math
from
abc
import
ABC
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.distributed
as
dist
from
colossalai.utils
import
get_current_device
from
colossalai.context
import
MOE_CONTEXT
from
colossalai.nn.layer.moe._operation
import
moe_cumsum
from
typing
import
Callable
,
Optional
from
torch.distributed
import
ProcessGroup
class MoeRouter(nn.Module, ABC):
    """Base class for all MoE routers.

    It stores the routing hyper-parameters, computes the per-expert capacity and keeps the
    auxiliary routing loss of the latest forward pass until it is popped.

    Args:
        k_value (int): The value of top_k.
        capacity_factor_train (float): Capacity factor in routing of training.
        capacity_factor_eval (float): Capacity factor in routing of evaluation.
        min_capacity (int): The minimum number of the capacity of each expert.
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

    def __init__(self,
                 k_value: int,
                 capacity_factor_train: float,
                 capacity_factor_eval: float,
                 min_capacity: int,
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True):
        super().__init__()
        self.k_value = k_value
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks
        # Holds the loss between set_routing_loss() and pop_routing_loss(); None when empty.
        self._routing_loss = None

    def get_capacity(self, logits_shape):
        """Return the per-expert capacity for logits of shape ``logits_shape``.

        The value is ``floor(k_value * factor * logits_shape[-2] / logits_shape[-1])``,
        rounded up to an even number and then raised to at least ``min_capacity``. The
        factor is the training or evaluation one depending on ``self.training``.

        Raises:
            AssertionError: If the resulting capacity is not positive.
        """
        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
        capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
        capacity += capacity % 2    # round odd values up to the next even number
        capacity = max(capacity, self.min_capacity)
        # Explicit raise instead of `assert` so the check also runs under `python -O`;
        # AssertionError is kept so the exception type is unchanged.
        if capacity <= 0:
            raise AssertionError(f"expert capacity must be positive, got {capacity}")
        return capacity

    def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
        """Store the auxiliary loss of the current forward pass.

        Raises:
            AssertionError: If a previously stored loss has not been popped yet.
        """
        if self._routing_loss is not None:
            raise AssertionError("the previous routing loss has not been popped yet")
        self._routing_loss = aux_loss

    def pop_routing_loss(self) -> torch.Tensor:
        """Return the stored auxiliary loss and clear it.

        Raises:
            AssertionError: If no loss is currently stored.
        """
        if self._routing_loss is None:
            raise AssertionError("no routing loss has been set")
        reservation = self._routing_loss
        self._routing_loss = None
        return reservation
class Top1Router(MoeRouter):
    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
    for routing usage. More detailed function can be found in the paper about Switch Transformer
    of Google.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing of training.
        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
        min_capacity (int, optional): The minimum number of the capacity of each expert.
        select_policy (str, optional): The policy about tokens selection, ``"first"`` or ``"random"``.
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 select_policy: str = "first",
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True):
        super().__init__(k_value=1,
                         capacity_factor_train=capacity_factor_train,
                         capacity_factor_eval=capacity_factor_eval,
                         min_capacity=min_capacity,
                         noisy_func=noisy_func,
                         drop_tks=drop_tks)
        self.select_policy = select_policy
        assert select_policy in {"first", "random"}
        if select_policy == "random":
            # Bound sampler: calling self.uniform(shape) draws U(0, 1) values of that shape.
            self.uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=get_current_device()),
                high=torch.tensor(1.0, device=get_current_device())).rsample

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
        """Route every token to its highest-scoring expert.

        The auxiliary loss is stored through ``set_routing_loss`` on every call.

        Args:
            inputs (torch.Tensor): Gate scores of dtype ``torch.float``; softmax and argmax
                are taken over the last dim and tokens are enumerated along dim 0.
            use_kernel (bool, optional): Selects which of the two return layouts is produced.
            ep_group (ProcessGroup, optional): Group passed to ``dist.all_reduce`` when the
                capacity is synchronised (evaluation with ``drop_tks=False``).

        Returns:
            ``(logits, mask, dest_idx, num_experts * capacity)`` when ``use_kernel`` is True,
            otherwise ``(combine_weights, sec_mask)`` where ``sec_mask`` is
            ``combine_weights.bool()``.
        """
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        assert inputs.dtype == torch.float
        logits = F.softmax(inputs, dim=-1)
        num_experts = logits.size(-1)
        capacity = self.get_capacity(logits.shape)

        top1_idx = torch.argmax(inputs, dim=-1)
        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

        # calculate the auxiliary loss
        me = torch.mean(logits, dim=0)
        ce = torch.mean(mask.float(), dim=0)
        l_aux = num_experts * torch.sum(me * ce)
        self.set_routing_loss(l_aux)

        if not self.training and not self.drop_tks:
            # No token may be dropped: use the largest per-expert load across the group.
            max_num = torch.max(torch.sum(mask, dim=0))
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = max_num.item()

        if self.select_policy == "random":
            rand_mask = mask * self.uniform(mask.shape)
            # torch.topk raises when k exceeds the size of dim 0. The capacity can be larger
            # than the token count (e.g. because of min_capacity), so clamp k.
            topk_k = min(capacity, mask.size(0))
            _, dispatch_idx = torch.topk(rand_mask, k=topk_k, dim=0)
            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
            ranks = moe_cumsum(mask)
        elif self.select_policy == "first":
            ranks = moe_cumsum(mask)
            mask = mask * torch.lt(ranks, capacity)
        else:
            raise NotImplementedError("Not support such select policy yet.")

        # Collapse the expert dim: one slot index per token (0 where the token was dropped).
        ranks = torch.sum(mask * ranks, dim=-1)

        if use_kernel:
            mask = torch.sum(mask, dim=-1)
            mask = torch.stack([mask], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
            return logits, mask, dest_idx, num_experts * capacity
        else:
            ranks = F.one_hot(ranks, num_classes=capacity)
            weight = mask * logits.type_as(inputs)
            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
            sec_mask = combine_weights.bool()
            return combine_weights, sec_mask
class Top2Router(MoeRouter):
    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
    for routing usage. More detailed function can be found in the paper about ViT-MoE.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing of training.
        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
        min_capacity (int, optional): The minimum number of the capacity of each expert
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether drops tokens in evaluation.
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_func: Callable = None,
                 drop_tks: bool = True):
        super().__init__(
            k_value=2,
            capacity_factor_train=capacity_factor_train,
            capacity_factor_eval=capacity_factor_eval,
            min_capacity=min_capacity,
            noisy_func=noisy_func,
            drop_tks=drop_tks,
        )

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
        """Route every token to its two highest-probability experts.

        The auxiliary loss is stored through ``set_routing_loss`` on every call. Returns
        ``(logits, mask, dest_idx, num_experts * capacity)`` when ``use_kernel`` is True,
        otherwise ``(cb_weight, cb_weight.bool())``.
        """
        if self.training and self.noisy_func is not None:
            inputs = self.noisy_func(inputs)

        assert inputs.dtype == torch.float
        probs = F.softmax(inputs, dim=-1)
        expert_count = probs.size(-1)
        capacity = self.get_capacity(probs.shape)

        # first choice, then the best of the remaining experts
        first_choice = torch.argmax(probs, dim=-1)
        first_mask = F.one_hot(first_choice, num_classes=expert_count).to(torch.int32)
        remaining = probs.masked_fill(first_mask.bool(), float("-inf"))
        second_choice = torch.argmax(remaining, dim=-1)
        second_mask = F.one_hot(second_choice, num_classes=expert_count).to(torch.int32)

        chosen = first_mask + second_mask

        # calculate the auxiliary loss
        mean_prob = probs.mean(dim=0)
        mean_load = chosen.float().mean(dim=0)
        # div 2 to normalize it to 1
        aux_loss = expert_count * (mean_prob * mean_load).sum() / 2.0
        self.set_routing_loss(aux_loss)

        if not (self.training or self.drop_tks):
            busiest = chosen.sum(dim=0).max()
            dist.all_reduce(busiest, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = busiest.item()

        first_rank = moe_cumsum(first_mask)
        second_rank = moe_cumsum(second_mask)
        # second choices are placed after all first choices of the same expert
        second_rank += torch.sum(first_mask, dim=-2, keepdim=True)

        first_mask *= torch.lt(first_rank, capacity)
        second_mask *= torch.lt(second_rank, capacity)

        first_rank = torch.sum(first_mask * first_rank, dim=-1)
        second_rank = torch.sum(second_mask * second_rank, dim=-1)

        if use_kernel:
            kept = torch.stack(
                [torch.sum(first_mask, dim=-1), torch.sum(second_mask, dim=-1)],
                dim=0,
            ).to(torch.int32)
            dest_idx = torch.stack(
                [first_choice * capacity + first_rank, second_choice * capacity + second_rank],
                dim=0,
            ).to(torch.int32)
            return probs, kept, dest_idx, expert_count * capacity

        first_weight = first_mask * probs.type_as(inputs)
        second_weight = second_mask * probs.type_as(inputs)
        first_slot = F.one_hot(first_rank, num_classes=capacity)
        second_slot = F.one_hot(second_rank, num_classes=capacity)

        first_cb = first_weight.unsqueeze(2) * first_slot.unsqueeze(1)
        second_cb = second_weight.unsqueeze(2) * second_slot.unsqueeze(1)
        combined = first_cb + second_cb
        return combined, combined.bool()
tests/test_moe/test_grad_handler.py
View file @
a088022e
...
...
@@ -32,7 +32,7 @@ def run_test(rank, world_size, port):
moe_layer
=
MoeLayer
(
DIM
,
num_experts
,
router
,
exp
)
layer_list
.
append
(
moe_layer
)
model
=
nn
.
Sequential
(
*
layer_list
)
model
=
nn
.
ModuleList
(
layer_list
)
model
=
model
.
to
(
get_current_device
())
sync_moe_model_param
(
model
)
...
...
@@ -49,8 +49,9 @@ def run_test(rank, world_size, port):
grad
=
torch
.
randn_like
(
data
)
MOE_CONTEXT
.
reset_loss
()
outputs
=
model
(
data
)
outputs
.
backward
(
grad
)
for
layer
in
layer_list
:
data
,
_
=
layer
(
data
)
data
.
backward
(
grad
)
grad_handler
.
handle_gradient
()
assert_equal_in_group
(
layer_list
[
0
].
experts
.
experts
[
0
].
weight
.
grad
,
dist_dict
[
1
].
dp_group
)
...
...
tests/test_moe/test_kernel.py
View file @
a088022e
...
...
@@ -44,7 +44,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
# use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine
layer
.
use_kernel
=
False
old_out
=
layer
(
tokens
)
old_out
,
_
=
layer
(
tokens
)
ech
=
old_out
.
shape
grad
=
torch
.
randn
(
ech
,
device
=
get_current_device
())
old_out
.
backward
(
grad
)
# get gradient
...
...
@@ -58,7 +58,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
layer
.
gate_weight
.
grad
.
zero_
()
layer
.
use_kernel
=
True
new_out
=
layer
(
tokens
)
# get outputs through colossal kernel
new_out
,
_
=
layer
(
tokens
)
# get outputs through colossal kernel
if
data_type
==
torch
.
float32
:
check_equal
(
old_out
,
new_out
)
...
...
tests/test_moe/test_moe_zero_init.py
View file @
a088022e
...
...
@@ -19,20 +19,39 @@ from colossalai.utils import get_current_device
from
tests.test_zero.common
import
CONFIG
class
MoeModel
(
Checkpoint
Module
):
class
MoeModel
(
nn
.
Module
):
def
__init__
(
self
,
checkpoint
:
bool
=
False
):
class
TestSubModule
(
CheckpointModule
):
def
__init__
(
self
):
super
().
__init__
(
checkpoint
)
self
.
proj1
=
nn
.
Linear
(
4
,
16
)
expert_cls
=
nn
.
Linear
expert_args_dict
=
dict
(
in_features
=
16
,
out_features
=
16
)
self
.
moe
=
MoeModule
(
dim_model
=
16
,
num_experts
=
8
,
use_residual
=
True
,
expert_cls
=
expert_cls
,
**
expert_args_dict
)
self
.
proj2
=
nn
.
Linear
(
16
,
4
)
self
.
moe
=
MoeModule
(
dim_model
=
16
,
num_experts
=
8
,
use_residual
=
True
,
expert_cls
=
expert_cls
,
**
expert_args_dict
)
self
.
proj
=
nn
.
Linear
(
16
,
4
)
def
_forward
(
self
,
x
):
x
,
y
=
self
.
moe
(
x
)
x
=
self
.
proj
(
x
)
return
x
,
y
super
().
__init__
()
self
.
test_embed
=
nn
.
Linear
(
4
,
16
)
self
.
test_transform
=
TestSubModule
()
def
forward
(
self
,
x
):
x
=
self
.
proj1
(
x
)
x
=
self
.
moe
(
x
)
x
=
self
.
proj2
(
x
)
MOE_CONTEXT
.
reset_loss
()
x
=
self
.
test_embed
(
x
)
x
,
y
=
self
.
test_transform
(
x
)
MOE_CONTEXT
.
add_loss
(
y
)
return
x
...
...
tests/test_moe/test_moe_zero_model.py
View file @
a088022e
...
...
@@ -4,6 +4,8 @@ import colossalai
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.nn
import
MoeLoss
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
...
...
@@ -26,7 +28,8 @@ def run_model_test(enable_autocast, shard_strategy_class):
shard_strategy
=
shard_strategy_class
()
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'no_leaf_module'
)
_
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
_
,
train_dataloader
,
_
,
optimizer_class
,
_
=
get_components_func
()
criterion
=
MoeLoss
(
aux_weight
=
0.01
,
loss_fn
=
torch
.
nn
.
CrossEntropyLoss
)
with
ZeroInitContext
(
target_device
=
torch
.
device
(
'cuda'
,
torch
.
cuda
.
current_device
()),
shard_strategy
=
shard_strategy
,
...
...
@@ -59,7 +62,6 @@ def run_model_test(enable_autocast, shard_strategy_class):
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
MOE_CONTEXT
.
setup
(
seed
=
42
)
MOE_CONTEXT
.
reset_loss
()
run_model_test
()
...
...
tests/test_moe/test_moe_zero_optim.py
View file @
a088022e
...
...
@@ -5,6 +5,7 @@ import pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.nn
import
MoeLoss
from
colossalai.nn.optimizer
import
CPUAdam
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
...
...
@@ -60,7 +61,8 @@ def _run_test_sharded_optim_v2(cpu_offload,
return
MOE_CONTEXT
.
reset_loss
()
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'no_leaf_module'
)
_
,
train_dataloader
,
_
,
optimizer_class
,
criterion
=
get_components_func
()
_
,
train_dataloader
,
_
,
optimizer_class
,
_
=
get_components_func
()
criterion
=
MoeLoss
(
aux_weight
=
0.01
,
loss_fn
=
torch
.
nn
.
CrossEntropyLoss
)
with
ZeroInitContext
(
target_device
=
torch
.
device
(
'cpu'
)
if
cpu_offload
else
get_current_device
(),
shard_strategy
=
shard_strategy
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment