OpenDAS / ColossalAI · Commits · f7f22487

Unverified commit f7f22487, authored Sep 22, 2022 by HELSON, committed by GitHub on Sep 22, 2022.

[moe] fix MoE bugs (#1628)

* remove forced FP32 modules
* correct no_shard-contexts' positions
Parent: 38c68b5b

Showing 7 changed files with 26 additions and 33 deletions.
colossalai/nn/layer/moe/experts.py        +1  -1
colossalai/nn/layer/moe/layers.py         +17 -14
colossalai/zero/init_ctx/init_context.py  +2  -1
tests/test_moe/test_kernel.py             +4  -3
tests/test_moe/test_moe_zero_init.py      +1  -7
tests/test_moe/test_moe_zero_optim.py     +0  -6
tests/test_zero/common.py                 +1  -1
colossalai/nn/layer/moe/experts.py

```diff
@@ -24,6 +24,7 @@ class MoeExperts(nn.Module):
         self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
 
 
+@no_shard_zero_decrator(is_replicated=False)
 class Experts(MoeExperts):
     """A wrapper class to create experts. It will create E experts across the
     moe model parallel group, where E is the number of experts. Every expert
@@ -35,7 +36,6 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """
 
-    @no_shard_zero_decrator(is_replicated=False)
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
```
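Note: moving the decorator from `__init__` onto the class changes what gets wrapped. `Experts(...)` now goes through the decorator's wrapper, which must hand the constructed instance back to the caller (this is what the init_context.py change further down fixes). Below is a minimal sketch of that mechanism, assuming the standard three-level shape for a parameterized decorator; only `_no_shard` and the context-manager call are visible in this commit's diff, the rest is a stand-in:

```python
from contextlib import contextmanager


@contextmanager
def no_shard_zero_context(is_replicated: bool):
    # stand-in for ColossalAI's ZeRO no-shard context; here it only logs
    print(f"enter no-shard context (is_replicated={is_replicated})")
    yield
    print("exit no-shard context")


def no_shard_zero_decrator(is_replicated: bool = True):

    def _wrapper(init_func):

        def _no_shard(*args, **kwargs):
            with no_shard_zero_context(is_replicated):
                ret = init_func(*args, **kwargs)
            return ret    # forwards the instance when init_func is a class

        return _no_shard

    return _wrapper


@no_shard_zero_decrator(is_replicated=False)
class Experts:

    def __init__(self, num_experts: int):
        # parameter creation here now runs inside the no-shard context
        self.num_experts = num_experts


experts = Experts(4)              # prints enter/exit around construction
assert experts.num_experts == 4   # holds only because _no_shard returns ret
```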
colossalai/nn/layer/moe/layers.py

```diff
@@ -228,6 +228,7 @@ class FP32LinearGate(nn.Module):
         return F.linear(x, self.weight)
 
 
+@no_shard_zero_decrator(is_replicated=True)
 class MoeLayer(nn.Module):
     """A MoE layer, that puts its input tensor to its gate and uses the output logits
     to router all tokens, is mainly used to exchange all tokens for every expert across
@@ -241,12 +242,11 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """
 
-    @no_shard_zero_decrator(is_replicated=True)
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = FP32LinearGate(dim_model, num_experts)
+        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
@@ -254,16 +254,14 @@ class MoeLayer(nn.Module):
         self.ep_size = experts.dist_info.ep_size
         self.num_local_experts = experts.num_local_experts
+        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
 
     def a2a_process(self, dispatch_data: torch.Tensor):
         expert_input = AllToAll.apply(dispatch_data, self.ep_group)
         input_shape = expert_input.shape
         expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
         expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output
@@ -274,16 +272,22 @@ class MoeLayer(nn.Module):
         return expert_out
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         # reshape the input tokens
         tokens = inputs.reshape(-1, self.d_model)
-        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
-        gate_output = self.gate(fp32_input)
-        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
+
+        # the data type of the inputs in the gating should be fp32
+        fp32_input = tokens.to(torch.float)
+        fp32_weight = self.gate_weight.to(torch.float)
+        gate_output = F.linear(fp32_input, fp32_weight)
+
+        # the result from the router
+        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
 
         if self.use_kernel:
-            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            sec_mask_f = router_res[1].type_as(inputs)
+            sec_mask_f = route_result_list[1].type_as(inputs)
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
 
         # dispatch_data [e, c, h]
@@ -295,12 +299,11 @@ class MoeLayer(nn.Module):
             raise NotImplementedError("This kind of communication has not been implemented yet.\nPlease use Experts "
                                       "build function.")
         # expert_output [e, c, h]
 
         if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
-            ans = MoeCombine.apply(expert_output, *router_res)
+            ans = MoeCombine.apply(expert_output, *route_result_list)
         else:
-            combine_weights = router_res[0].type_as(inputs)
+            combine_weights = route_result_list[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
```
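Note: two things are worth pulling out of this diff. First, the forced-FP32 gate module is gone: the gate is now a bare `nn.Parameter`, and both the tokens and the weight are cast to fp32 at the point of use. A self-contained sketch of that gating path, with shapes assumed from the surrounding code (`tokens` is `[s, d_model]`, `gate_weight` is `[num_experts, d_model]`):

```python
import math

import torch
import torch.nn.functional as F

num_experts, d_model = 4, 16
tokens = torch.randn(8, d_model, dtype=torch.float16)    # fp16 activations
gate_weight = torch.nn.Parameter(torch.empty(num_experts, d_model))
torch.nn.init.trunc_normal_(gate_weight, std=math.sqrt(0.1 / d_model))

# gating logits are always computed in fp32, whatever the input dtype
fp32_input = tokens.to(torch.float)
fp32_weight = gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)          # [8, num_experts]
assert gate_output.dtype == torch.float32
```

Second, the non-kernel dispatch path is a single matmul over the routing mask. A sketch assuming `route_result_list[1]` is a `[s, e, c]` one-hot mask (token s goes to capacity slot c of expert e); the `[e, c, h]` result matches the `# dispatch_data [e, c, h]` comment in the diff:

```python
import torch

s, e, c, h = 8, 4, 2, 16    # tokens, experts, capacity, hidden (assumed sizes)
tokens = torch.randn(s, h)
sec_mask_f = torch.zeros(s, e, c)
sec_mask_f[torch.arange(s), torch.randint(e, (s,)), torch.randint(c, (s,))] = 1.0

# permute the mask to [e, c, s], then matmul with [s, h] -> [e, c, h]
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
assert dispatch_data.shape == (e, c, h)
```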
colossalai/zero/init_ctx/init_context.py

```diff
@@ -258,7 +258,8 @@ def no_shard_zero_decrator(is_replicated: bool = True):
 
         def _no_shard(*args, **kwargs):
             with no_shard_zero_context(is_replicated):
-                init_func(*args, **kwargs)
+                ret = init_func(*args, **kwargs)
+            return ret
 
         return _no_shard
```
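Note: these two added lines are what make the decorator moves above work. While `no_shard_zero_decrator` wrapped only `__init__`, the discarded return value was always `None` anyway; now that it wraps classes, the wrapper is the construction entry point, and swallowing the return value would make every decorated class construct to `None`. A hypothetical reproduction of that failure mode (not ColossalAI code):

```python
# a wrapper that drops the return value breaks class decoration
def broken_decorator(init_func):

    def _no_shard(*args, **kwargs):
        init_func(*args, **kwargs)    # instance created, then lost

    return _no_shard


@broken_decorator
class Gate:
    pass


assert Gate() is None    # the constructed instance never reaches the caller
```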
tests/test_moe/test_kernel.py

```diff
@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
     expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+    layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # save all results
     o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
+    o_gt_grad = layer.gate_weight.grad.data.clone()
 
     # reset all gradients
     tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
+    layer.gate_weight.grad.zero_()
 
     layer.use_kernel = True
     new_out = layer(tokens)    # get ouputs through colossal kernel
@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     new_out.backward(grad)    # get new type gradient
     n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
+    n_gt_grad = layer.gate_weight.grad.data.clone()
 
     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)
```
tests/test_moe/test_moe_zero_init.py

```diff
@@ -58,15 +58,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')
 
-        # the weights in the gate should be fp32
-        if 'gate' in name:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
-        else:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
-
         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.colo_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
         else:
             assert param.colo_attr.sharded_data_tensor.is_sharded
```
tests/test_moe/test_moe_zero_optim.py

```diff
@@ -94,12 +94,6 @@ def _run_test_sharded_optim_v2(cpu_offload,
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     apex_grad_handler = MoeGradientHandler(model)
 
-    # Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
-    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
-        if 'gate' in n:
-            p.data = p.float()
-            p.data.copy_(zp.colo_attr.data_payload)
-
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
             break
```
tests/test_zero/common.py

```diff
@@ -135,5 +135,5 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         else:
             zero_p = zero_p.colo_attr.data_payload.to(p.device)
 
-        assert p.dtype == zero_p.dtype
+        assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
```