OpenDAS / ColossalAI · Commits

Unverified commit 1a1d68b0, authored Mar 31, 2023 by HELSON, committed by GitHub on Mar 31, 2023
Parent: fee2af86

[moe] add checkpoint for moe models (#3354)

* [moe] add checkpoint for moe models
* [hotfix] fix bugs in unit test

Showing 5 changed files with 517 additions and 384 deletions (+517, -384)
colossalai/nn/layer/moe/__init__.py      +10  -9
colossalai/nn/layer/moe/checkpoint.py    +40  -0
colossalai/nn/layer/moe/experts.py       +203 -172
colossalai/nn/layer/moe/layers.py        +210 -203
tests/test_moe/test_moe_checkpoint.py    +54  -0
colossalai/nn/layer/moe/__init__.py @ 1a1d68b0

from .checkpoint import load_moe_model, save_moe_model
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router

...

@@ -5,5 +6,5 @@ from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

__all__ = [
    'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
]
colossalai/nn/layer/moe/checkpoint.py (new file, mode 100644)

import torch
import torch.distributed as dist
import torch.nn as nn

from .experts import MoeExperts


def save_moe_model(model: nn.Module, save_path: str):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, save_path)
    dist.barrier()


def load_moe_model(model: nn.Module, load_path: str):
    state_dict = torch.load(load_path)
    for prefix, module in model.named_modules():
        if prefix.endswith('.moe_layer.experts'):
            # this module should be an Experts instance
            assert isinstance(module, MoeExperts)

            ep_rank = dist.get_rank(module.dist_info.ep_group)
            num_local = module.num_local_experts
            for i in range(num_local):
                # map this rank's global expert id to its local slot in the checkpoint keys
                expert_id = ep_rank * num_local + i
                for name, _ in module.experts[i].named_parameters():
                    cur_key = f'{prefix}.experts.{i}.{name}'
                    param_key = f'{prefix}.experts.{expert_id}.{name}'
                    load_param = state_dict[param_key]
                    state_dict[cur_key] = load_param

            # drop the keys of experts that live on other expert-parallel ranks
            for name, _ in module.experts[0].named_parameters():
                pop_pre = f'{prefix}.experts.'
                pop_suf = f'.{name}'
                for i in range(num_local, module.num_total_experts):
                    pop_key = f'{pop_pre}{i}{pop_suf}'
                    state_dict.pop(pop_key)

    model.load_state_dict(state_dict)
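The load path depends on a fixed placement rule: the global experts are laid out contiguously across the expert-parallel group, so local slot i on rank r always holds global expert r * num_local + i. A minimal standalone sketch of that arithmetic (the sizes below are hypothetical, not from the commit):

# Placement rule assumed by load_moe_model: with num_local = E // W experts per
# expert-parallel rank, local slot i on rank r corresponds to global expert r * num_local + i.
num_total_experts = 8                              # hypothetical example values
ep_world_size = 4
num_local = num_total_experts // ep_world_size     # 2 experts per rank

for ep_rank in range(ep_world_size):
    for i in range(num_local):
        expert_id = ep_rank * num_local + i
        print(f"rank {ep_rank}: local key 'experts.{i}' <- checkpoint key 'experts.{expert_id}'")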
colossalai/nn/layer/moe/experts.py @ 1a1d68b0

import math
from copy import deepcopy
from typing import Type

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_decrator


class MoeExperts(nn.Module):
...
...
@@ -20,6 +23,7 @@ class MoeExperts(nn.Module):

        assert comm_name in {"all_to_all", "all_gather"}, \
            "This kind of communication has not been implemented yet.\n Please use Experts build function."
        self.comm_name = comm_name
        self.num_total_experts = num_experts
        # Get the configuration of experts' deployment and parallel information from moe context
        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
...
...
@@ -61,6 +65,33 @@ class Experts(MoeExperts):

        output = torch.cat(expert_output, dim=1).contiguous()
        return output

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        assert keep_vars == False, "Only support keep_vars=False now"
        dp_rank = dist.get_rank(self.dist_info.dp_group)
        ep_rank = dist.get_rank(self.dist_info.ep_group)

        # map each locally held expert to its global expert id
        submodule_dict = dict()
        example_submodule = None
        for name, subm in self.experts.named_modules():
            if subm is self.experts:
                continue
            module_number = self.num_local_experts * ep_rank + int(name)
            submodule_dict[module_number] = subm
            example_submodule = subm

        if dp_rank == 0:
            local_prefix = prefix + 'experts.'
            buffer_module = deepcopy(example_submodule)
            for i in range(self.num_total_experts):
                source_rank = i // self.num_local_experts
                current_prefix = local_prefix + str(i) + '.'
                comm_module = submodule_dict.get(i, buffer_module)
                for name, param in comm_module.named_parameters():
                    # the owner of expert i broadcasts it; the other ranks receive into a buffer
                    dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
                    if ep_rank == 0:
                        destination[current_prefix + name] = param.data.cpu()

        dist.barrier()


class FFNExperts(MoeExperts):
    """Use torch.bmm to speed up for multiple experts.
...
...
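The state_dict override above collects parameters that live on other expert-parallel ranks by having every rank join one broadcast per expert. A self-contained sketch of that pattern (a toy setup, not the commit's code: two processes, gloo backend, one single-tensor "expert" per rank):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    local_param = torch.full((2,), float(rank))    # stand-in for this rank's expert weight
    gathered = {}
    for expert_id in range(world_size):            # one expert per rank in this toy setup
        buf = local_param.clone() if expert_id == rank else torch.empty(2)
        dist.broadcast(buf, src=expert_id)         # the owner sends, every other rank receives
        if rank == 0:
            gathered[f'experts.{expert_id}.weight'] = buf
    if rank == 0:
        print(gathered)                            # rank 0 now holds every expert's weight
    dist.barrier()


if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)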
colossalai/nn/layer/moe/layers.py @ 1a1d68b0

import math
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe._operation import (
    COL_MOE_KERNEL_FLAG,
    AllGather,
    AllToAll,
    MoeCombine,
    MoeDispatch,
    ReduceScatter,
)
from colossalai.nn.layer.moe.experts import Experts, MoeExperts
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator


@no_shard_zero_decrator(is_replicated=True)
...
...
@@ -178,16 +185,16 @@ class MoeModule(nn.Module):

        self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

        if expert_instance is not None:
            my_experts = expert_instance
        else:
            assert expert_cls is not None, \
                "Expert class can't be None when experts instance is not given"
            my_experts = Experts(expert_cls, num_experts, **expert_args)

        self.moe_layer = MoeLayer(dim_model=dim_model,
                                  num_experts=num_experts,
                                  router=self.moe_router,
                                  experts=my_experts)

    def forward(self, inputs: torch.Tensor):
        moe_output, l_aux = self.moe_layer(inputs)
...
...
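The rename from self.experts to a local my_experts means the expert module is registered on MoeModule only once, through self.moe_layer. Presumably this keeps every expert parameter under a single moe_layer.experts.* key, the prefix load_moe_model matches on. A toy illustration (not from the commit; MoeLayerStub is a hypothetical stand-in) of the duplicate state_dict keys that double registration would produce:

import torch.nn as nn


class MoeLayerStub(nn.Module):
    """Hypothetical stand-in for MoeLayer: it registers the experts it is given."""

    def __init__(self, experts):
        super().__init__()
        self.experts = experts


class DoubleRegistered(nn.Module):
    def __init__(self):
        super().__init__()
        experts = nn.Linear(4, 4)
        self.experts = experts                    # first registration, on the parent
        self.moe_layer = MoeLayerStub(experts)    # second registration, via the layer


print(list(DoubleRegistered().state_dict().keys()))
# ['experts.weight', 'experts.bias', 'moe_layer.experts.weight', 'moe_layer.experts.bias']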
tests/test_moe/test_moe_checkpoint.py (new file, mode 100644)

import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print
from tests.test_zero.common import CONFIG


def exam_moe_checkpoint():
    with ColoInitContext(device=get_current_device()):
        model = MoeModel(checkpoint=True)
    save_moe_model(model, 'temp_path.pth')

    with ColoInitContext(device=get_current_device()):
        other_model = MoeModel(checkpoint=True)
    load_moe_model(other_model, 'temp_path.pth')

    # the reloaded model must match the original parameter by parameter
    state_0 = model.state_dict()
    state_1 = other_model.state_dict()
    for k, v in state_0.items():
        u = state_1.get(k)
        assert torch.equal(u.data, v.data)

    if dist.get_rank() == 0:
        os.remove('temp_path.pth')


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    exam_moe_checkpoint()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_checkpoint(world_size=4)
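The test spawns its own process group, so it can be run directly (python tests/test_moe/test_moe_checkpoint.py, which uses a world size of 4) or through pytest, which parametrizes over world sizes 2 and 4 and, via rerun_if_address_is_in_use, retries if the chosen port is already taken.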