OpenDAS / ColossalAI · Commit 1a1d68b0 (unverified)
Authored Mar 31, 2023 by HELSON; committed by GitHub on Mar 31, 2023
Parent: fee2af86

[moe] add checkpoint for moe models (#3354)

* [moe] add checkpoint for moe models
* [hotfix] fix bugs in unit test

Showing 5 changed files with 517 additions and 384 deletions (+517, -384):
- colossalai/nn/layer/moe/__init__.py (+10, -9)
- colossalai/nn/layer/moe/checkpoint.py (+40, -0, new file)
- colossalai/nn/layer/moe/experts.py (+203, -172)
- colossalai/nn/layer/moe/layers.py (+210, -203)
- tests/test_moe/test_moe_checkpoint.py (+54, -0, new file)
colossalai/nn/layer/moe/__init__.py @ 1a1d68b0 (+10, -9):

```python
from .checkpoint import load_moe_model, save_moe_model
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

__all__ = [
    'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
]
```
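Since `__init__.py` now re-exports the helpers, downstream code can import them from the package root instead of reaching into the `checkpoint` submodule; a one-line sketch:

```python
# The checkpoint helpers are re-exported from the moe package root.
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
```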
colossalai/nn/layer/moe/checkpoint.py @ 1a1d68b0 (new file, mode 100644, +40):

```python
import torch
import torch.distributed as dist
import torch.nn as nn

from .experts import MoeExperts


def save_moe_model(model: nn.Module, save_path: str):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, save_path)
    dist.barrier()


def load_moe_model(model: nn.Module, load_path: str):
    state_dict = torch.load(load_path)

    for prefix, module in model.named_modules():
        if prefix.endswith('.moe_layer.experts'):
            # this module should be an Experts instance
            assert isinstance(module, MoeExperts)

            ep_rank = dist.get_rank(module.dist_info.ep_group)
            num_local = module.num_local_experts
            for i in range(num_local):
                expert_id = ep_rank * num_local + i
                for name, _ in module.experts[i].named_parameters():
                    cur_key = f'{prefix}.experts.{i}.{name}'
                    param_key = f'{prefix}.experts.{expert_id}.{name}'
                    load_param = state_dict[param_key]
                    state_dict[cur_key] = load_param

            for name, _ in module.experts[0].named_parameters():
                pop_pre = f'{prefix}.experts.'
                pop_suf = f'.{name}'
                for i in range(num_local, module.num_total_experts):
                    pop_key = f'{pop_pre}{i}{pop_suf}'
                    state_dict.pop(pop_key)

    model.load_state_dict(state_dict)
```
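For orientation, here is a minimal usage sketch of the two new entry points. It assumes a distributed group has already been launched on every rank (both helpers call into `torch.distributed`) and that the model registers its experts under module names ending in `.moe_layer.experts`, which is the pattern `load_moe_model` matches on; `build_model` is a hypothetical factory.

```python
# Minimal sketch, assuming colossalai.launch(...) and MOE_CONTEXT.setup(...)
# have already run on every rank; build_model() is a hypothetical factory
# whose expert modules live under names ending in '.moe_layer.experts'.
from colossalai.nn.layer.moe import load_moe_model, save_moe_model

model = build_model()
save_moe_model(model, 'moe_ckpt.pth')       # rank 0 writes the full state dict

restored = build_model()
load_moe_model(restored, 'moe_ckpt.pth')    # each rank remaps and keeps only its local experts
```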
colossalai/nn/layer/moe/experts.py @ 1a1d68b0 (+203, -172):

```python
import math
from copy import deepcopy
from typing import Type

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_decrator


class MoeExperts(nn.Module):
    """Basic class for experts in MoE. It stores what kind of communication experts use
    to exchange tokens, how many experts live on a single GPU, and parallel information such as
    expert parallel size, data parallel size and their distributed communication groups.
    """

    def __init__(self, comm_name: str, num_experts: int):
        super().__init__()
        assert comm_name in {"all_to_all", "all_gather"}, \
            "This kind of communication has not been implemented yet.\nPlease use Experts build function."
        self.comm_name = comm_name
        self.num_total_experts = num_experts
        # Get the configuration of experts' deployment and parallel information from the moe context
        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)


@no_shard_zero_decrator(is_replicated=False)
class Experts(MoeExperts):
    """A wrapper class to create experts. It will create E experts across the
    moe model parallel group, where E is the number of experts. Every expert
    is an instance of the expert class given in the initialization parameters.

    Args:
        expert_cls (:class:`torch.nn.Module`): The class of all experts
        num_experts (int): The number of experts
        expert_args: Args used to initialize experts; they can be found in the corresponding expert class
    """

    def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
        super().__init__("all_to_all", num_experts)

        # Use seed to make every expert different from others
        with seed(ParallelMode.TENSOR):
            self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])

        # Attach parallel information for all parameters in Experts
        for exp in self.experts:
            for param in exp.parameters():
                param.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs: torch.Tensor):
        # Split inputs for each expert
        expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
        expert_output = []

        # Get outputs from each expert
        for i in range(self.num_local_experts):
            expert_output.append(self.experts[i](expert_input[i]))

        # Concatenate all outputs together
        output = torch.cat(expert_output, dim=1).contiguous()
        return output

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        assert keep_vars == False, "Only support keep_vars=False now"
        dp_rank = dist.get_rank(self.dist_info.dp_group)
        ep_rank = dist.get_rank(self.dist_info.ep_group)

        # Collect this rank's local experts, keyed by their global expert ids
        submodule_dict = dict()
        example_submodule = None
        for name, subm in self.experts.named_modules():
            if subm is self.experts:
                continue
            module_number = self.num_local_experts * ep_rank + int(name)
            submodule_dict[module_number] = subm
            example_submodule = subm

        if dp_rank == 0:
            local_prefix = prefix + 'experts.'
            buffer_module = deepcopy(example_submodule)
            for i in range(self.num_total_experts):
                source_rank = i // self.num_local_experts
                current_prefix = local_prefix + str(i) + '.'
                comm_module = submodule_dict.get(i, buffer_module)
                for name, param in comm_module.named_parameters():
                    dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
                    if ep_rank == 0:
                        destination[current_prefix + name] = param.data.cpu()

        dist.barrier()


class FFNExperts(MoeExperts):
    """Use torch.bmm to speed up computation for multiple experts.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_to_all", num_experts)

        self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
        self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))

        self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
        self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))

        s1 = math.sqrt(0.1 / d_model)
        s2 = math.sqrt(0.1 / d_ff)

        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.w1, std=s1)
            nn.init.trunc_normal_(self.b1, std=s1)
            nn.init.trunc_normal_(self.w2, std=s2)
            nn.init.trunc_normal_(self.b2, std=s2)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        for param in self.parameters():
            param.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs):    # inputs [g, el, c, h]
        el = inputs.size(1)
        h = inputs.size(-1)

        inputs = inputs.transpose(0, 1)
        inshape = inputs.shape
        inputs = inputs.reshape(el, -1, h)

        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
        out_act = self.act(out_ff)
        with seed(ParallelMode.TENSOR):
            out_inter = self.drop(out_act)

        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
        with seed(ParallelMode.TENSOR):
            outputs = self.drop(out_model)    # outputs [el, gc, h]

        outputs = outputs.reshape(inshape)
        outputs = outputs.transpose(0, 1).contiguous()
        return outputs


class TPExperts(MoeExperts):
    """Use tensor parallelism to split each expert evenly, which can deploy experts in
    cases where the number of experts can't be divided by the maximum expert parallel size, or
    the maximum expert parallel size can't be divided by the number of experts.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)

        assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
            "d_ff should be divisible by the maximum expert parallel size"

        p_ff = d_ff // MOE_CONTEXT.max_ep_size

        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))

        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))

        s1 = math.sqrt(0.1 / d_model)
        s2 = math.sqrt(0.1 / d_ff)

        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.w1, std=s1)
            nn.init.trunc_normal_(self.b1, std=s1)
            nn.init.trunc_normal_(self.w2, std=s2)

        nn.init.trunc_normal_(self.b2, std=s2)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        self.w1.__setattr__('moe_info', self.dist_info)
        self.w2.__setattr__('moe_info', self.dist_info)
        self.b1.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs):    # inputs [g, e, c, h]
        e = inputs.size(1)
        h = inputs.size(-1)

        inputs = inputs.transpose(0, 1)
        inshape = inputs.shape
        inputs = inputs.reshape(e, -1, h)

        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
        out_act = self.act(out_ff)
        with seed(ParallelMode.TENSOR):
            out_inter = self.drop(out_act)

        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
        outputs = self.drop(out_model)    # outputs [e, gc, h]

        outputs = outputs.reshape(inshape)
        outputs = outputs.transpose(0, 1).contiguous()
        return outputs    # outputs [g, e, c, h]
```
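The invariant that `Experts.state_dict` above and `load_moe_model` in checkpoint.py share is the flat numbering `expert_id = ep_rank * num_local + i`: global expert ids are laid out contiguously per expert-parallel rank. A standalone sketch of that mapping, with illustrative sizes:

```python
# Standalone illustration of the expert numbering used by both
# Experts.state_dict and load_moe_model; the sizes are illustrative.
num_total_experts = 8
ep_size = 4                               # expert-parallel world size
num_local = num_total_experts // ep_size  # experts held per rank

for ep_rank in range(ep_size):
    owned = [ep_rank * num_local + i for i in range(num_local)]
    print(f"ep_rank {ep_rank} owns global experts {owned}")

# ep_rank 0 owns global experts [0, 1]
# ep_rank 1 owns global experts [2, 3]
# ep_rank 2 owns global experts [4, 5]
# ep_rank 3 owns global experts [6, 7]
```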
colossalai/nn/layer/moe/layers.py @ 1a1d68b0 (+210, -203):

```python
import math
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe._operation import (
    COL_MOE_KERNEL_FLAG,
    AllGather,
    AllToAll,
    MoeCombine,
    MoeDispatch,
    ReduceScatter,
)
from colossalai.nn.layer.moe.experts import Experts, MoeExperts
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator


@no_shard_zero_decrator(is_replicated=True)
class MoeLayer(nn.Module):
    """A MoE layer that feeds its input tensor to its gate and uses the output logits
    to route all tokens. It is mainly used to exchange all tokens for every expert across
    the moe tensor group by all-to-all communication. It then gets the output of all
    experts and exchanges the output back. Finally it returns the output of the moe system.

    Args:
        dim_model (int): Dimension of model.
        num_experts (int): The number of experts.
        router (MoeRouter): Instance of router used in routing.
        experts (MoeExperts): Instance of experts generated by Expert.
    """

    def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
        super().__init__()
        self.d_model = dim_model
        self.num_experts = num_experts
        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
        self.router: MoeRouter = router
        self.experts: MoeExperts = experts
        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
        self.ep_group = experts.dist_info.ep_group
        self.ep_size = experts.dist_info.ep_size
        self.num_local_experts = experts.num_local_experts

        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))

    def a2a_process(self, dispatch_data: torch.Tensor):
        expert_input = AllToAll.apply(dispatch_data, self.ep_group)
        input_shape = expert_input.shape
        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
        expert_output = self.experts(expert_input)
        expert_output = expert_output.reshape(input_shape)
        expert_output = AllToAll.apply(expert_output, self.ep_group)
        return expert_output

    def tp_process(self, dispatch_data: torch.Tensor):
        expert_in = AllGather.apply(dispatch_data, self.ep_group)
        expert_out = self.experts(expert_in)
        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
        return expert_out

    def forward(self, inputs: torch.Tensor) -> Tuple:
        # reshape the input tokens
        tokens = inputs.reshape(-1, self.d_model)

        # the data type of the inputs in the gating should be fp32
        fp32_input = tokens.to(torch.float)
        fp32_weight = self.gate_weight.to(torch.float)
        gate_output = F.linear(fp32_input, fp32_weight)

        # the result from the router
        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

        if self.use_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
        else:
            sec_mask_f = route_result_list[1].type_as(inputs)
            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

        # dispatch_data [e, c, h]
        if self.experts.comm_name == "all_to_all":
            expert_output = self.a2a_process(dispatch_data)
        elif self.experts.comm_name == "all_gather":
            expert_output = self.tp_process(dispatch_data)
        else:
            raise NotImplementedError("This kind of communication has not been implemented yet.\nPlease use Experts "
                                      "build function.")
        # expert_output [e, c, h]

        if self.use_kernel:
            expert_output = expert_output.reshape(-1, self.d_model)
            ans = MoeCombine.apply(expert_output, *route_result_list)
        else:
            combine_weights = route_result_list[0].type_as(inputs)
            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
            expert_output = expert_output.view(-1, expert_output.shape[-1])
            ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)
        l_aux = self.router.pop_routing_loss()

        return ans, l_aux


class MoeModule(nn.Module):
    """A class for users to create MoE modules in their models.

    Args:
        dim_model (int): Hidden dimension of training model
        num_experts (int): The number of experts
        top_k (int, optional): The number of experts to which each token is dispatched
        capacity_factor_train (float, optional): Capacity factor in routing during training
        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
        min_capacity (int, optional): The minimum capacity of each expert
        noisy_policy (str, optional): The policy of the noisy function. Currently 'Jitter' and 'Gaussian' are supported.
            'Jitter' can be found in the `Switch Transformer paper`_.
            'Gaussian' can be found in the `ViT-MoE paper`_.
        drop_tks (bool, optional): Whether to drop tokens during evaluation
        use_residual (bool, optional): Makes this MoE layer a Residual MoE.
            More information can be found in the `Microsoft paper`_.
        residual_instance (nn.Module, optional): The instance of the residual module in a Residual MoE
        expert_instance (MoeExperts, optional): The instance of the experts module in MoeLayer
        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
        expert_args (optional): The args used to initialize experts when no instance is given

    .. _Switch Transformer paper:
        https://arxiv.org/abs/2101.03961
    .. _ViT-MoE paper:
        https://arxiv.org/abs/2106.05974
    .. _Microsoft paper:
        https://arxiv.org/abs/2201.05596
    """

    def __init__(self,
                 dim_model: int,
                 num_experts: int,
                 top_k: int = 1,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_policy: Optional[str] = None,
                 drop_tks: bool = True,
                 use_residual: bool = False,
                 residual_instance: Optional[nn.Module] = None,
                 expert_instance: Optional[MoeExperts] = None,
                 expert_cls: Optional[Type[nn.Module]] = None,
                 **expert_args):
        super().__init__()

        noisy_func = None
        if noisy_policy is not None:
            if noisy_policy == 'Jitter':
                noisy_func = UniformNoiseGenerator()
            elif noisy_policy == 'Gaussian':
                noisy_func = NormalNoiseGenerator(num_experts)
            else:
                raise NotImplementedError("Unsupported input noisy policy")

        if top_k == 1:
            moe_router_cls = Top1Router
        elif top_k == 2:
            moe_router_cls = Top2Router
        else:
            raise NotImplementedError("top_k > 2 is not supported yet")

        self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
                                         capacity_factor_eval=capacity_factor_eval,
                                         min_capacity=min_capacity,
                                         noisy_func=noisy_func,
                                         drop_tks=drop_tks)
        self.use_residual = use_residual
        if use_residual:
            if residual_instance is not None:
                self.residual_module = residual_instance
            else:
                assert expert_cls is not None, \
                    "Expert class can't be None when residual instance is not given"
                self.residual_module = expert_cls(**expert_args)

            with no_shard_zero_context():
                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

        if expert_instance is not None:
            self.experts = expert_instance
        else:
            assert expert_cls is not None, \
                "Expert class can't be None when experts instance is not given"
            self.experts = Experts(expert_cls, num_experts, **expert_args)

        self.moe_layer = MoeLayer(dim_model=dim_model,
                                  num_experts=num_experts,
                                  router=self.moe_router,
                                  experts=self.experts)

    def forward(self, inputs: torch.Tensor):
        moe_output, l_aux = self.moe_layer(inputs)

        if self.use_residual:
            residual_output = self.residual_module(inputs)
            combine_coef = self.residual_combine(inputs)
            combine_coef = F.softmax(combine_coef, dim=-1)
            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
        else:
            output = moe_output

        return output, l_aux
```
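A sketch of how `MoeModule` is typically assembled, assuming a launched ColossalAI environment (the `Experts` constructor queries `MOE_CONTEXT` for its parallel layout); `ExpertFFN` is a hypothetical expert class and the sizes are illustrative.

```python
# Sketch only: assumes colossalai.launch(...) and MOE_CONTEXT.setup(...)
# have already run; ExpertFFN is a hypothetical expert class.
import torch
import torch.nn as nn

from colossalai.nn.layer.moe import MoeModule


class ExpertFFN(nn.Module):

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x):
        return self.net(x)


moe = MoeModule(dim_model=64,
                num_experts=4,
                top_k=1,
                noisy_policy='Jitter',
                expert_cls=ExpertFFN,
                d_model=64,
                d_ff=256)    # d_model/d_ff are forwarded to ExpertFFN via **expert_args

tokens = torch.randn(8, 16, 64, device='cuda')    # [batch, seq, hidden]
output, l_aux = moe(tokens)                       # l_aux is the auxiliary routing loss
```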
tests/test_moe/test_moe_checkpoint.py @ 1a1d68b0 (new file, mode 100644, +54):

```python
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print
from tests.test_zero.common import CONFIG


def exam_moe_checkpoint():
    with ColoInitContext(device=get_current_device()):
        model = MoeModel(checkpoint=True)
    save_moe_model(model, 'temp_path.pth')

    with ColoInitContext(device=get_current_device()):
        other_model = MoeModel(checkpoint=True)
    load_moe_model(other_model, 'temp_path.pth')

    state_0 = model.state_dict()
    state_1 = other_model.state_dict()
    for k, v in state_0.items():
        u = state_1.get(k)
        assert torch.equal(u.data, v.data)

    if dist.get_rank() == 0:
        os.remove('temp_path.pth')


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    exam_moe_checkpoint()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_checkpoint(world_size=4)
```