OpenDAS / ColossalAI / Commits / aff9d354

Commit aff9d354 (unverified), authored Mar 19, 2022 by HELSON, committed by GitHub on Mar 19, 2022

[MOE] polish moe_env (#467)

parent bccbc158
Showing 8 changed files with 96 additions and 90 deletions (+96 -90):

    colossalai/core.py                                             +1  -1
    colossalai/engine/gradient_handler/_moe_gradient_handler.py    +3  -3
    colossalai/nn/layer/moe/_operation.py                          +3  -5
    colossalai/nn/layer/moe/experts.py                             +45 -40
    colossalai/nn/layer/moe/layers.py                              +30 -29
    colossalai/nn/layer/moe/utils.py                               +6  -6
    colossalai/nn/loss/loss_moe.py                                 +5  -3
    colossalai/utils/moe.py                                        +3  -3
colossalai/core.py

```diff
@@ -4,4 +4,4 @@
 from colossalai.context import ParallelContext, MoeContext

 global_context = ParallelContext.get_instance()
-moe_context = MoeContext.get_instance()
+MOE_CONTEXT = MoeContext.get_instance()
```
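The rename from moe_context to MOE_CONTEXT signals that this handle is a process-wide singleton constant rather than an ordinary variable. Below is a minimal sketch of the get_instance() accessor pattern this implies; the fields on the example class are placeholders, not MoeContext's real attributes.

```python
# Minimal sketch of a get_instance()-style singleton, mirroring
# `MOE_CONTEXT = MoeContext.get_instance()`. The fields here are hypothetical.
class ExampleMoeContext:
    _instance = None

    def __init__(self):
        self.world_size = 1      # hypothetical field
        self.aux_loss = 0.0      # hypothetical field

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance


# Module-level constant handle, shared by every importer of this module.
EXAMPLE_MOE_CONTEXT = ExampleMoeContext.get_instance()
```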
colossalai/engine/gradient_handler/_moe_gradient_handler.py

```diff
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler
@@ -30,5 +30,5 @@ class MoeGradientHandler(BaseGradientHandler):
             bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))

         for ep_size in param_dict:
-            if ep_size != 1 and ep_size != moe_env.world_size:
-                bucket_allreduce(param_list=param_dict[ep_size], group=moe_env.information[ep_size].dp_group)
+            if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                bucket_allreduce(param_list=param_dict[ep_size], group=MOE_CONTEXT.information[ep_size].dp_group)
```
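For context, a hedged sketch of the per-ep_size gradient all-reduce this handler performs, assuming an already-initialised torch.distributed job. The `param_dict` and `moe_info` arguments mirror the shapes suggested by the diff (get_moe_epsize_param_dict and MOE_CONTEXT.information); bucketing is omitted for brevity.

```python
import torch.distributed as dist

def allreduce_moe_grads(param_dict, moe_info, world_size):
    """Sketch: reduce gradients of MoE parameters inside their data-parallel groups.

    param_dict: {ep_size: [parameters]}        (as get_moe_epsize_param_dict suggests)
    moe_info:   {ep_size: obj with .dp_group}  (stand-in for MOE_CONTEXT.information)
    """
    for ep_size, params in param_dict.items():
        # ep_size == 1 uses the plain data-parallel group in the real handler;
        # ep_size == world_size means no replicas exist, so nothing to reduce.
        if ep_size == 1 or ep_size == world_size:
            continue
        group = moe_info[ep_size].dp_group
        for p in params:
            if p.grad is not None:
                dist.all_reduce(p.grad, group=group)
```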
colossalai/nn/layer/moe/_operation.py

```diff
@@ -4,11 +4,11 @@ from torch import Tensor
 from typing import Any, Tuple, Optional
 from torch.distributed import ProcessGroup

-U_CUDA_MODE = False
+COL_MOE_KERNEL_FLAG = False
 try:
     import colossal_moe_cuda
-    U_CUDA_MODE = True
+    COL_MOE_KERNEL_FLAG = True
 except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")

@@ -17,7 +17,6 @@ class AllGather(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group

@@ -40,7 +39,6 @@ class ReduceScatter(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group

@@ -149,7 +147,7 @@ class MoeCombine(torch.autograd.Function):
 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag and U_CUDA_MODE:
+    if flag and COL_MOE_KERNEL_FLAG:
         return colossal_moe_cuda.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
```
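The rename to COL_MOE_KERNEL_FLAG keeps the same try-import-and-fall-back pattern. Below is a small sketch of that pattern and of the pure-PyTorch fallback used by moe_cumsum; the shape condition is copied from the diff, and the kernel's behaviour beyond "cumsum along dim 0 minus one" is an assumption.

```python
import torch

try:
    import colossal_moe_cuda          # optional fused kernel from cuda_ext
    KERNEL_AVAILABLE = True
except ImportError:
    KERNEL_AVAILABLE = False

def example_moe_cumsum(inputs: torch.Tensor) -> torch.Tensor:
    dim0 = inputs.size(0)
    # Same shape guard as in the diff: the kernel only handles these leading sizes.
    ok = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
    if ok and KERNEL_AVAILABLE:
        return colossal_moe_cuda.cumsum_sub_one(inputs)
    # Pure-PyTorch fallback: inclusive cumsum shifted down by one.
    return torch.cumsum(inputs, dim=0) - 1
```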
colossalai/nn/layer/moe/experts.py

```diff
@@ -2,18 +2,24 @@ import math
 import torch
 import torch.nn as nn
-from colossalai.global_variables import moe_env
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
+from colossalai.core import MOE_CONTEXT


 class MoeExperts(nn.Module):
+    """Basic class for experts in MoE. It stores what kind of communication expersts use
+    to exchange tokens, how many experts in a single GPU and parallel information such as
+    expert parallel size, data parallel size and their distributed communication groups.
+    """

-    def __init__(self, comm: str):
+    def __init__(self, comm_name: str, num_experts: int):
         super().__init__()
-        assert comm in {"all_to_all", "all_gather"}, \
+        assert comm_name in {"all_to_all", "all_gather"}, \
             "This kind of communication has not been implemented yet.\nPlease use Experts build function."
-        self.comm = comm
+        self.comm_name = comm_name
+        # Get the configuration of experts' deployment and parallel information from moe contex
+        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)


 class Experts(MoeExperts):
@@ -29,53 +35,48 @@ class Experts(MoeExperts):
     """

     def __init__(self, expert, num_experts, **expert_args):
-        super().__init__("all_to_all")
+        super().__init__("all_to_all", num_experts)

-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divied by moe model size"
-        num_local_experts = num_experts // moe_env.model_parallel_size

-        with seed(ParallelMode.MOE_MODEL):
-            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
+        # Use seed to make every expert different from others
+        with seed(ParallelMode.TENSOR):
+            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(self.num_local_experts)])

+        # Attach parallel information for all parameters in Experts
         for exp in self.experts:
             for param in exp.parameters():
-                param.__setattr__('moe_param', True)
+                param.__setattr__('moe_info', self.dist_info)

-        self.num_local_experts = num_local_experts

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
+        # Split inputs for each expert
         expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
         expert_output = []

+        # Get outputs from each expert
         for i in range(self.num_local_experts):
             expert_output.append(self.experts[i](expert_input[i]))

+        # Concatenate all outputs together
         output = torch.cat(expert_output, dim=1).contiguous()
         return output
```
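As an aside, here is a self-contained sketch of the chunk-and-concatenate pattern Experts.forward uses to feed each local expert its slice of the buffer; the tensor sizes and the Linear experts are illustrative assumptions.

```python
import torch
import torch.nn as nn

num_local_experts, groups, capacity, hidden = 2, 4, 6, 16
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_local_experts)])
inputs = torch.randn(groups, num_local_experts * capacity, hidden)

# Split along dim 1 so each local expert receives its own capacity-sized buffer.
expert_input = torch.chunk(inputs, num_local_experts, dim=1)
expert_output = [experts[i](expert_input[i]) for i in range(num_local_experts)]

# Stitch the per-expert results back into the original layout.
output = torch.cat(expert_output, dim=1).contiguous()
print(output.shape)   # torch.Size([4, 12, 16])
```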
colossalai/nn/layer/moe/experts.py (continued):

```diff
 class FFNExperts(MoeExperts):
+    """Use torch.bmm to speed up for multiple experts.
+    """

     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__("all_to_all")
+        super().__init__("all_to_all", num_experts)

-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divied by moe model size"
-        num_local_experts = num_experts // moe_env.model_parallel_size

-        self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device()))
-        self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device()))
-        self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device()))
-        self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device()))
+        self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
+        self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
+        self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
+        self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))

         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)

-        with seed(ParallelMode.MOE_MODEL):
+        with seed(ParallelMode.TENSOR):
             nn.init.trunc_normal_(self.w1, std=s1)
             nn.init.trunc_normal_(self.b1, std=s1)
             nn.init.trunc_normal_(self.w2, std=s2)

@@ -85,7 +86,7 @@ class FFNExperts(MoeExperts):
         self.drop = nn.Dropout(p=drop_rate)

         for param in self.parameters():
-            param.__setattr__('moe_param', True)
+            param.__setattr__('moe_info', self.dist_info)

     def forward(self, inputs):    # inputs [g, el, c, h]
@@ -99,9 +100,9 @@ class FFNExperts(MoeExperts):
         out_ff = torch.baddbmm(self.b1, inputs, self.w1)
         out_act = self.act(out_ff)
         with seed(ParallelMode.TENSOR):
-            inter = self.drop(out_act)
+            out_inter = self.drop(out_act)

-        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
         with seed(ParallelMode.TENSOR):
             outputs = self.drop(out_model)    # outputs [el, gc, h]

@@ -111,14 +112,18 @@ class FFNExperts(MoeExperts):
 class TPExperts(MoeExperts):
+    """Use tensor parallelism to split each expert evenly, which can deploy experts in
+    case that the number of experts can't be divied by maximum expert parallel size or
+    maximum expert parallel size can't be divied by the number of experts.
+    """

     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__("all_gather")
+        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)

-        assert d_ff % moe_env.model_parallel_size == 0, \
-            "d_ff should be divied by moe model size"
+        assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
+            "d_ff should be divied by maximum expert parallel size"

-        p_ff = d_ff // moe_env.model_parallel_size
+        p_ff = d_ff // MOE_CONTEXT.max_ep_size

         self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
         self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))

@@ -129,7 +134,7 @@ class TPExperts(MoeExperts):
         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)

-        with seed(ParallelMode.MOE_MODEL):
+        with seed(ParallelMode.TENSOR):
             nn.init.trunc_normal_(self.w1, std=s1)
             nn.init.trunc_normal_(self.b1, std=s1)
             nn.init.trunc_normal_(self.w2, std=s2)

@@ -139,9 +144,9 @@ class TPExperts(MoeExperts):
         self.act = nn.GELU() if activation is None else activation
         self.drop = nn.Dropout(p=drop_rate)

-        self.w1.__setattr__('moe_param', True)
-        self.w2.__setattr__('moe_param', True)
-        self.b1.__setattr__('moe_param', True)
+        self.w1.__setattr__('moe_info', self.dist_info)
+        self.w2.__setattr__('moe_info', self.dist_info)
+        self.b1.__setattr__('moe_info', self.dist_info)

     def forward(self, inputs):    # inputs [g, e, c, h]
@@ -155,9 +160,9 @@ class TPExperts(MoeExperts):
         out_ff = torch.baddbmm(self.b1, inputs, self.w1)
         out_act = self.act(out_ff)
         with seed(ParallelMode.TENSOR):
-            inter = self.drop(out_act)
+            out_inter = self.drop(out_act)

-        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
         outputs = self.drop(out_model)    # outputs [e, gc, h]

         outputs = outputs.reshape(inshape)
```
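FFNExperts and TPExperts both rely on torch.baddbmm to fuse the per-expert bias add and matmul, which is why the weights carry a leading expert dimension. A short sketch with assumed sizes:

```python
import torch

experts, tokens, d_model, d_ff = 4, 8, 16, 32
x  = torch.randn(experts, tokens, d_model)      # one batch of tokens per expert
w1 = torch.randn(experts, d_model, d_ff)
b1 = torch.randn(experts, 1, d_ff)              # broadcast over the token dimension

# baddbmm(bias, batch1, batch2) == bias + batch1 @ batch2, batched over experts.
out = torch.baddbmm(b1, x, w1)
ref = torch.matmul(x, w1) + b1
assert torch.allclose(out, ref, atol=1e-5)
print(out.shape)   # torch.Size([4, 8, 32])
```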
colossalai/nn/layer/moe/layers.py

```diff
@@ -4,14 +4,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import moe_env
-from colossalai.context import ParallelMode
+from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
-from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
+from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
 from .utils import autocast_softmax
-from typing import Callable
+from typing import Callable, Optional
+from torch.distributed import ProcessGroup


 class Top1Router(nn.Module):
@@ -19,8 +18,8 @@ class Top1Router(nn.Module):
     for routing usage. More deailted function can be found in the paper about Switch Transformer
     of Google.

-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param select_policy: The policy about tokens selection
     :param noisy_func: Noisy function used in logits
@@ -66,7 +65,7 @@ class Top1Router(nn.Module):
         assert capacity > 0
         return capacity

-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
@@ -82,10 +81,10 @@ class Top1Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(mask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce)
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(mask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass
@@ -103,7 +102,7 @@ class Top1Router(nn.Module):
         ranks = torch.sum(mask * ranks, dim=-1)

-        if cuda_mode:
+        if use_kernel:
             mask = torch.sum(mask, dim=-1)
             mask = torch.stack([mask], dim=0).to(torch.int32)
             dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)

@@ -120,8 +119,8 @@ class Top2Router(nn.Module):
     """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
     for routing usage. More deailted function can be found in the paper about ViT-MoE.

-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param noisy_func: Noisy function used in logits
     :param drop_tks: Whether drops tokens in evaluation
@@ -157,7 +156,7 @@ class Top2Router(nn.Module):
         assert capacity > 0
         return capacity

-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
         # inputs: [s, h]
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
@@ -177,10 +176,10 @@ class Top2Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(cmask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(cmask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass
@@ -195,7 +194,7 @@ class Top2Router(nn.Module):
         rank1 = torch.sum(mask1 * rank1, dim=-1)
         rank2 = torch.sum(mask2 * rank2, dim=-1)

-        if cuda_mode:
+        if use_kernel:
             mask1 = torch.sum(mask1, dim=-1)
             mask2 = torch.sum(mask2, dim=-1)
```
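Both routers accumulate a Switch Transformer style load-balancing loss into MOE_CONTEXT via add_loss(l_aux). A standalone sketch of that computation follows; the shapes and the top-1 selection are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts = 64, 8
logits = F.softmax(torch.randn(num_tokens, num_experts), dim=-1)   # router probabilities
top1_idx = torch.argmax(logits, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts)                # token-to-expert assignment

me = torch.mean(logits, dim=0)             # mean routed probability per expert
ce = torch.mean(mask.float(), dim=0)       # fraction of tokens sent to each expert
l_aux = num_experts * torch.sum(me * ce)   # minimised when both are uniform
```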
colossalai/nn/layer/moe/layers.py (continued):

```diff
@@ -241,34 +240,36 @@ class MoeLayer(nn.Module):
         self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
         self.router = router
         self.experts = experts
-        self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
+        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
+        self.ep_group = experts.dist_info.ep_group
+        self.ep_size = experts.dist_info.ep_size
+        self.num_local_experts = experts.num_local_experts

     def a2a_process(self, dispatch_data: torch.Tensor):
-        expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_input = AllToAll.apply(dispatch_data, self.ep_group)
         input_shape = expert_input.shape
-        expert_input = expert_input.reshape(moe_env.model_parallel_size,
-                                            self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
+        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
-        expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
+        expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output

     def tp_process(self, dispatch_data: torch.Tensor):
-        expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_in = AllGather.apply(dispatch_data, self.ep_group)
         expert_out = self.experts(expert_in)
-        expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
+        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
         gate_output = self.gate(tokens)
-        router_res = self.router(gate_output, self.cuda_mode)
+        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

-        if self.cuda_mode:
+        if self.use_kernel:
             dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
@@ -276,16 +277,16 @@ class MoeLayer(nn.Module):
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

         # dispatch_data [e, c, h]
-        if self.experts.comm == "all_to_all":
+        if self.experts.comm_name == "all_to_all":
             expert_output = self.a2a_process(dispatch_data)
-        elif self.experts.comm == "all_gather":
+        elif self.experts.comm_name == "all_gather":
             expert_output = self.tp_process(dispatch_data)
         else:
             raise NotImplementedError("This kind of communication has not been implemented yet.\nPlease use Experts "
                                       "build function.")

         # expert_output [e, c, h]
-        if self.cuda_mode:
+        if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
             ans = MoeCombine.apply(expert_output, *router_res)
         else:
```
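When the fused kernel is unavailable, MoeLayer falls back to dispatching tokens with a dense mask matmul. A sketch of that path with assumed sizes:

```python
import torch

s, e, c, h = 16, 4, 8, 32        # tokens, experts, capacity, hidden size (assumed)
tokens = torch.randn(s, h)

# One-hot mask marking which (expert, capacity slot) each token was routed to.
sec_mask_f = torch.zeros(s, e, c)
sec_mask_f[torch.arange(s), torch.randint(0, e, (s,)), torch.randint(0, c, (s,))] = 1.0

# [e, c, s] @ [s, h] -> [e, c, h]: every expert receives its capacity-sized buffer.
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
print(dispatch_data.shape)       # torch.Size([4, 8, 32])
```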
colossalai/nn/layer/moe/utils.py

```diff
 import torch
 import torch.nn.functional as F
 from colossalai.utils import get_current_device
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts

@@ -36,7 +36,7 @@ class UniformNoiseGenerator:
     :type eps: float
     """

-    def __init__(self, eps: float):
+    def __init__(self, eps: float = 1e-2):
         self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
                                                            high=torch.tensor(1.0 + eps,
                                                                              device=get_current_device())).rsample

@@ -55,10 +55,10 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    moe_mp_size = moe_env.model_parallel_size
-    if num_experts % moe_mp_size == 0:
+    mep_size = MOE_CONTEXT.max_ep_size
+    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
         return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    elif d_ff % moe_mp_size == 0:
+    elif d_ff % mep_size == 0:
         return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     else:
-        raise NotImplementedError(f"Can not build {num_experts} experts in {moe_mp_size} GPUS.")
+        raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
```
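The relaxed condition in build_ffn_experts now also accepts the case where max_ep_size is a multiple of num_experts. A small sketch of the selection logic, using strings in place of the real FFNExperts/TPExperts classes:

```python
def choose_expert_impl(num_experts: int, d_ff: int, max_ep_size: int) -> str:
    # Whole experts can be placed per rank when either count divides the other.
    if num_experts % max_ep_size == 0 or max_ep_size % num_experts == 0:
        return "FFNExperts"
    # Otherwise split each expert along d_ff across ranks, if that divides evenly.
    elif d_ff % max_ep_size == 0:
        return "TPExperts"
    raise NotImplementedError(f"Can not build {num_experts} experts in {max_ep_size} GPUS.")

assert choose_expert_impl(8, 4096, 4) == "FFNExperts"   # 8 % 4 == 0
assert choose_expert_impl(2, 4096, 4) == "FFNExperts"   # 4 % 2 == 0 (newly accepted)
assert choose_expert_impl(6, 4096, 4) == "TPExperts"    # falls back to the d_ff split
```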
colossalai/nn/loss/loss_moe.py

```diff
 import torch.nn as nn
 from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT


 @LOSSES.register_module
@@ -14,6 +14,7 @@ class MoeCrossEntropyLoss(_Loss):
     :type aux_weight: float, optional
     """

     def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
         super().__init__()
         self.loss = nn.CrossEntropyLoss(*args, **kwargs)
@@ -21,7 +22,7 @@ class MoeCrossEntropyLoss(_Loss):
     def forward(self, *args):
         main_loss = self.loss(*args)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss
@@ -37,6 +38,7 @@ class MoeLoss(_Loss):
     :type aux_weight: float
     :type loss_fn: Callable
     """

     def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
         super().__init__()
         self.loss_fn = loss_fn(*args, **kwargs)
@@ -44,5 +46,5 @@ class MoeLoss(_Loss):
     def forward(self, *args, **kwargs):
         main_loss = self.loss_fn(*args, **kwargs)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss
```
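Both loss classes add the auxiliary loss pulled from MOE_CONTEXT on top of the task loss. A self-contained sketch of that weighting, with the context accessor replaced by an explicit argument:

```python
import torch
import torch.nn as nn

class ExampleMoeLoss(nn.Module):
    """Sketch: task loss plus aux_weight * auxiliary balancing loss."""

    def __init__(self, aux_weight: float = 0.01):
        super().__init__()
        self.aux_weight = aux_weight
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, targets, aux_loss):
        # In the real classes, aux_loss comes from MOE_CONTEXT.get_loss().
        return self.loss_fn(logits, targets) + self.aux_weight * aux_loss

loss = ExampleMoeLoss()(torch.randn(4, 10), torch.randint(0, 10, (4,)), torch.tensor(0.3))
```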
colossalai/utils/moe.py

```diff
 import torch.nn as nn
 import torch.distributed as dist
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.context import ParallelMode
 from .common import is_using_ddp
 from typing import Dict, List
@@ -45,7 +45,7 @@ def sync_moe_model_param(model: nn.Module):
         for ep_size in param_dict:
             # When ep_size = world_size, communication is not needed
-            if ep_size != 1 and ep_size != moe_env.world_size:
-                src_rank = dist.get_rank(moe_env.information[ep_size].ep_group)
+            if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                src_rank = dist.get_rank(MOE_CONTEXT.information[ep_size].ep_group)
                 for param in param_dict[ep_size]:
                     dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
```
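A hedged sketch of the synchronisation pattern in sync_moe_model_param, assuming an already-initialised torch.distributed job; the group handles stand in for MOE_CONTEXT.information[ep_size] and the moe_info attribute attached to each parameter.

```python
import torch.distributed as dist

def sync_params(param_dict, moe_info, world_size):
    """Sketch: make replicas inside each data-parallel group start from identical weights."""
    for ep_size, params in param_dict.items():
        # ep_size == world_size: every rank owns distinct experts, nothing to sync.
        if ep_size == 1 or ep_size == world_size:
            continue
        src_rank = dist.get_rank(moe_info[ep_size].ep_group)
        for param in params:
            dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
```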