Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
65c0f380
Unverified
Commit
65c0f380
authored
Mar 21, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 21, 2022
Browse files
[format] polish name format for MOE (#481)
parent
8d3250d7
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
45 additions
and
32 deletions
+45
-32
colossalai/context/moe_context.py
colossalai/context/moe_context.py
+23
-12
colossalai/engine/gradient_handler/_moe_gradient_handler.py
colossalai/engine/gradient_handler/_moe_gradient_handler.py
+2
-1
colossalai/nn/layer/moe/experts.py
colossalai/nn/layer/moe/experts.py
+3
-2
colossalai/utils/moe.py
colossalai/utils/moe.py
+1
-1
tests/test_moe/test_grad_handler.py
tests/test_moe/test_grad_handler.py
+1
-1
tests/test_moe/test_moe_group.py
tests/test_moe/test_moe_group.py
+15
-15
No files found.
colossalai/context/moe_context.py
View file @
65c0f380
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
.parallel_mode
import
ParallelMode
from
.parallel_mode
import
ParallelMode
from
typing
import
Tuple
def
_check_sanity
():
def
_check_sanity
():
...
@@ -10,7 +11,7 @@ def _check_sanity():
...
@@ -10,7 +11,7 @@ def _check_sanity():
"pipeline parallel at present."
)
"pipeline parallel at present."
)
class
MoeInfo
:
class
Moe
Parallel
Info
:
"""Moe parallelism information, storing parallel sizes and groups.
"""Moe parallelism information, storing parallel sizes and groups.
"""
"""
...
@@ -78,11 +79,11 @@ class MoeContext:
...
@@ -78,11 +79,11 @@ class MoeContext:
self
.
use_kernel_optim
=
True
self
.
use_kernel_optim
=
True
self
.
has_setup
=
False
self
.
has_setup
=
False
self
.
_info_dict
=
dict
()
self
.
_
parallel_
info_dict
=
dict
()
@
property
@
property
def
information
(
self
):
def
parallel_info_dict
(
self
):
return
self
.
_info_dict
return
self
.
_
parallel_
info_dict
@
property
@
property
def
is_initialized
(
self
):
def
is_initialized
(
self
):
...
@@ -110,17 +111,27 @@ class MoeContext:
...
@@ -110,17 +111,27 @@ class MoeContext:
moe_set_seed
(
seed
)
moe_set_seed
(
seed
)
self
.
has_setup
=
True
self
.
has_setup
=
True
def
get_info
(
self
,
num_experts
:
int
):
def
get_info
(
self
,
num_experts
:
int
)
->
Tuple
[
int
,
MoeParallelInfo
]:
"""Automatically deploys experts and returns parallel information about
"""Calculate the Data Parallel Group and Expert Parallel Group.
distributed communication groups.
Parameters
----------
num_experts : int
The number of experts
Returns
-------
int, MoeParallelInfo
number of local experts, the MoeParallelInfo of the current ep_size
"""
"""
gt_flag
=
num_experts
%
self
.
max_ep_size
==
0
# check whether num_experts is greater
gt_flag
=
num_experts
%
self
.
max_ep_size
==
0
# check whether num_experts is greater
lt_flag
=
self
.
max_ep_size
%
num_experts
==
0
# check whether num_experts is less
lt_flag
=
self
.
max_ep_size
%
num_experts
==
0
# check whether num_experts is less
assert
gt_flag
or
lt_flag
,
"Automatic experts placement do not support such situation right now."
assert
gt_flag
or
lt_flag
,
"Automatic experts placement dose not not support expert number"
\
" is not a multiple of ep size or vice versa."
# If the number of experts is greater than maximum expert parallel size,
# If the number of experts is greater than maximum expert parallel
size. a.k.a ep_
size,
# there are multiple experts in each GPU and each GPU has different experts
# there are multiple experts in each GPU and each GPU has different experts
# So its data parallel size is 1
# So its data parallel size is 1
# Otherwise, there is only one expert in each GPU
# Otherwise, there is only one expert in each GPU
...
@@ -133,10 +144,10 @@ class MoeContext:
...
@@ -133,10 +144,10 @@ class MoeContext:
# Don't forget to multiply minimum data parallel size
# Don't forget to multiply minimum data parallel size
dp_size
*=
self
.
min_dp_size
dp_size
*=
self
.
min_dp_size
if
not
(
ep_size
in
self
.
information
):
if
not
(
ep_size
in
self
.
parallel_info_dict
):
self
.
information
[
ep_size
]
=
MoeInfo
(
ep_size
,
dp_size
)
self
.
parallel_info_dict
[
ep_size
]
=
Moe
Parallel
Info
(
ep_size
,
dp_size
)
return
num_local_experts
,
self
.
information
[
ep_size
]
return
num_local_experts
,
self
.
parallel_info_dict
[
ep_size
]
def
set_kernel_not_use
(
self
):
def
set_kernel_not_use
(
self
):
self
.
use_kernel_optim
=
False
self
.
use_kernel_optim
=
False
...
...
colossalai/engine/gradient_handler/_moe_gradient_handler.py
View file @
65c0f380
...
@@ -31,4 +31,5 @@ class MoeGradientHandler(BaseGradientHandler):
...
@@ -31,4 +31,5 @@ class MoeGradientHandler(BaseGradientHandler):
for
ep_size
in
param_dict
:
for
ep_size
in
param_dict
:
if
ep_size
!=
1
and
ep_size
!=
MOE_CONTEXT
.
world_size
:
if
ep_size
!=
1
and
ep_size
!=
MOE_CONTEXT
.
world_size
:
bucket_allreduce
(
param_list
=
param_dict
[
ep_size
],
group
=
MOE_CONTEXT
.
information
[
ep_size
].
dp_group
)
bucket_allreduce
(
param_list
=
param_dict
[
ep_size
],
group
=
MOE_CONTEXT
.
parallel_info_dict
[
ep_size
].
dp_group
)
colossalai/nn/layer/moe/experts.py
View file @
65c0f380
...
@@ -5,6 +5,7 @@ import torch.nn as nn
...
@@ -5,6 +5,7 @@ import torch.nn as nn
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.core
import
MOE_CONTEXT
from
colossalai.core
import
MOE_CONTEXT
from
typing
import
Type
class
MoeExperts
(
nn
.
Module
):
class
MoeExperts
(
nn
.
Module
):
...
@@ -34,12 +35,12 @@ class Experts(MoeExperts):
...
@@ -34,12 +35,12 @@ class Experts(MoeExperts):
:type num_experts: int
:type num_experts: int
"""
"""
def
__init__
(
self
,
expert
,
num_experts
,
**
expert_args
):
def
__init__
(
self
,
expert
_cls
:
Type
[
nn
.
Module
]
,
num_experts
:
int
,
**
expert_args
):
super
().
__init__
(
"all_to_all"
,
num_experts
)
super
().
__init__
(
"all_to_all"
,
num_experts
)
# Use seed to make every expert different from others
# Use seed to make every expert different from others
with
seed
(
ParallelMode
.
TENSOR
):
with
seed
(
ParallelMode
.
TENSOR
):
self
.
experts
=
nn
.
ModuleList
([
expert
(
**
expert_args
)
for
_
in
range
(
self
.
num_local_experts
)])
self
.
experts
=
nn
.
ModuleList
([
expert
_cls
(
**
expert_args
)
for
_
in
range
(
self
.
num_local_experts
)])
# Attach parallel information for all parameters in Experts
# Attach parallel information for all parameters in Experts
for
exp
in
self
.
experts
:
for
exp
in
self
.
experts
:
...
...
colossalai/utils/moe.py
View file @
65c0f380
...
@@ -46,6 +46,6 @@ def sync_moe_model_param(model: nn.Module):
...
@@ -46,6 +46,6 @@ def sync_moe_model_param(model: nn.Module):
for
ep_size
in
param_dict
:
for
ep_size
in
param_dict
:
# When ep_size = world_size, communication is not needed
# When ep_size = world_size, communication is not needed
if
ep_size
!=
1
and
ep_size
!=
MOE_CONTEXT
.
world_size
:
if
ep_size
!=
1
and
ep_size
!=
MOE_CONTEXT
.
world_size
:
src_rank
=
dist
.
get_rank
(
MOE_CONTEXT
.
information
[
ep_size
].
ep_group
)
src_rank
=
dist
.
get_rank
(
MOE_CONTEXT
.
parallel_info_dict
[
ep_size
].
ep_group
)
for
param
in
param_dict
[
ep_size
]:
for
param
in
param_dict
[
ep_size
]:
dist
.
broadcast
(
param
,
src
=
src_rank
,
group
=
param
.
moe_info
.
dp_group
)
dist
.
broadcast
(
param
,
src
=
src_rank
,
group
=
param
.
moe_info
.
dp_group
)
tests/test_moe/test_grad_handler.py
View file @
65c0f380
...
@@ -36,7 +36,7 @@ def run_test(rank, world_size, port):
...
@@ -36,7 +36,7 @@ def run_test(rank, world_size, port):
model
=
model
.
to
(
get_current_device
())
model
=
model
.
to
(
get_current_device
())
sync_moe_model_param
(
model
)
sync_moe_model_param
(
model
)
dist_dict
=
MOE_CONTEXT
.
information
dist_dict
=
MOE_CONTEXT
.
parallel_info_dict
assert_equal_in_group
(
layer_list
[
0
].
experts
.
experts
[
0
].
weight
.
data
,
dist_dict
[
1
].
dp_group
)
assert_equal_in_group
(
layer_list
[
0
].
experts
.
experts
[
0
].
weight
.
data
,
dist_dict
[
1
].
dp_group
)
assert_equal_in_group
(
layer_list
[
1
].
experts
.
experts
[
0
].
weight
.
data
,
dist_dict
[
2
].
dp_group
)
assert_equal_in_group
(
layer_list
[
1
].
experts
.
experts
[
0
].
weight
.
data
,
dist_dict
[
2
].
dp_group
)
# MoE model synchronization passed
# MoE model synchronization passed
...
...
tests/test_moe/test_moe_group.py
View file @
65c0f380
from
functools
import
partial
from
functools
import
partial
import
pytest
import
pytest
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -16,7 +15,8 @@ D_FF = 8
...
@@ -16,7 +15,8 @@ D_FF = 8
CONFIG
=
dict
()
CONFIG
=
dict
()
def
run_test
(
rank
,
world_size
,
port
):
def
run_test
(
rank
,
port
):
world_size
=
4
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
expert_module
=
nn
.
Linear
expert_module
=
nn
.
Linear
expert_factor
=
dict
(
in_features
=
D_MODEL
,
out_features
=
D_FF
,
device
=
get_current_device
())
expert_factor
=
dict
(
in_features
=
D_MODEL
,
out_features
=
D_FF
,
device
=
get_current_device
())
...
@@ -33,36 +33,36 @@ def run_test(rank, world_size, port):
...
@@ -33,36 +33,36 @@ def run_test(rank, world_size, port):
assert
exp3
.
num_local_experts
==
2
assert
exp3
.
num_local_experts
==
2
# experts deployment passed
# experts deployment passed
dist
_dict
=
MOE_CONTEXT
.
information
parallel_info
_dict
=
MOE_CONTEXT
.
parallel_info_dict
rank
=
dist
.
get_rank
()
rank
=
dist
.
get_rank
()
assert
len
(
dist
_dict
)
==
3
assert
len
(
parallel_info
_dict
)
==
3
assert
dist
.
get_rank
(
dist
_dict
[
4
].
ep_group
)
==
rank
assert
dist
.
get_rank
(
parallel_info
_dict
[
4
].
ep_group
)
==
rank
assert
dist
.
get_rank
(
dist
_dict
[
2
].
ep_group
)
==
rank
%
2
assert
dist
.
get_rank
(
parallel_info
_dict
[
2
].
ep_group
)
==
rank
%
2
assert
dist
.
get_rank
(
dist
_dict
[
1
].
ep_group
)
==
0
assert
dist
.
get_rank
(
parallel_info
_dict
[
1
].
ep_group
)
==
0
assert
dist
.
get_rank
(
dist
_dict
[
4
].
dp_group
)
==
0
assert
dist
.
get_rank
(
parallel_info
_dict
[
4
].
dp_group
)
==
0
assert
dist
.
get_rank
(
dist
_dict
[
2
].
dp_group
)
==
rank
//
2
assert
dist
.
get_rank
(
parallel_info
_dict
[
2
].
dp_group
)
==
rank
//
2
assert
dist
.
get_rank
(
dist
_dict
[
1
].
dp_group
)
==
rank
assert
dist
.
get_rank
(
parallel_info
_dict
[
1
].
dp_group
)
==
rank
# group creation passed
# group creation passed
model
=
nn
.
ModuleList
([
exp0
,
exp1
,
exp2
,
exp3
])
model
=
nn
.
ModuleList
([
exp0
,
exp1
,
exp2
,
exp3
])
model
=
model
.
to
(
get_current_device
())
model
=
model
.
to
(
get_current_device
())
sync_moe_model_param
(
model
)
sync_moe_model_param
(
model
)
assert_equal_in_group
(
exp0
.
experts
[
0
].
weight
.
data
,
dist
_dict
[
1
].
dp_group
)
assert_equal_in_group
(
exp0
.
experts
[
0
].
weight
.
data
,
parallel_info
_dict
[
1
].
dp_group
)
assert_equal_in_group
(
exp0
.
experts
[
0
].
bias
.
data
,
dist
_dict
[
1
].
dp_group
)
assert_equal_in_group
(
exp0
.
experts
[
0
].
bias
.
data
,
parallel_info
_dict
[
1
].
dp_group
)
# MOE experts layout success when ep_size = 1
# MOE experts layout success when ep_size = 1
assert_equal_in_group
(
exp1
.
experts
[
0
].
weight
.
data
,
dist
_dict
[
2
].
dp_group
)
assert_equal_in_group
(
exp1
.
experts
[
0
].
weight
.
data
,
parallel_info
_dict
[
2
].
dp_group
)
assert_equal_in_group
(
exp1
.
experts
[
0
].
bias
.
data
,
dist
_dict
[
2
].
dp_group
)
assert_equal_in_group
(
exp1
.
experts
[
0
].
bias
.
data
,
parallel_info
_dict
[
2
].
dp_group
)
# MOE experts layout success when ep_size = 2
# MOE experts layout success when ep_size = 2
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
def
test_moe_initialization
():
def
test_moe_initialization
():
world_size
=
4
world_size
=
4
run_func
=
partial
(
run_test
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
run_test
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment