OpenDAS / ColossalAI / Commits / 7d8e0338

Commit 7d8e0338 — [moe] init mixtral impl
Authored Dec 14, 2023 by Xuanlei Zhao; committed by ver217 on Feb 07, 2024.
Parent: c53ddda8
Changes: 28 files in this commit; this page shows 8 changed files with 302 additions and 198 deletions (+302 −198).
Changed files shown on this page:

    colossalai/moe/experts.py                      +5   −1
    colossalai/moe/layers.py                       +40  −32
    colossalai/moe/routers.py                      +26  −7
    colossalai/moe/utils.py                        +2   −0
    colossalai/zero/low_level/low_level_optim.py   +46  −16
    tests/test_moe/moe_utils.py                    +80  −3
    tests/test_moe/test_moe_zero_fwd_bwd.py        +48  −75
    tests/test_moe/test_moe_zero_optim.py          +55  −64
colossalai/moe/experts.py

@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
         self.ep_size = 1
 
         if gated:
-            self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
+            self.wi_gate = nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
+                )
+            )
             self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
         else:
             self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
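The only functional change here is the shape rule for the packed gate projection: `wi_gate` keeps the doubled `intermediate_size * 2` output dimension only when the activation is "swiglu", and falls back to `intermediate_size` otherwise. A minimal sketch of the resulting shapes (the concrete sizes below are made up for illustration, not from the commit):

    import torch
    import torch.nn as nn

    # Hypothetical sizes, only to show the shape rule introduced above.
    num_experts, hidden_size, intermediate_size = 8, 16, 32

    def make_wi_gate(activation: str) -> nn.Parameter:
        out_features = intermediate_size * 2 if activation == "swiglu" else intermediate_size
        return nn.Parameter(torch.empty(num_experts, hidden_size, out_features))

    assert make_wi_gate("swiglu").shape == (8, 16, 64)  # doubled projection for SwiGLU
    assert make_wi_gate("gelu").shape == (8, 16, 32)    # unchanged for other activations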
colossalai/moe/layers.py

@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         router_top_k: int = 1,
+        router_loss: bool = True,
+        router_norm: bool = False,
         router_capacity_factor_train: float = 1.25,
         router_capacity_factor_eval: float = 2.0,
         router_min_capacity: int = 4,

@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
         enable_kernel: bool = False,
         enable_comm_overlap: bool = False,
         enable_hierarchical_comm: bool = False,
+        return_gate_logits: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_experts = num_experts
         self.gated = mlp_gated
+        self.return_gate_logits = return_gate_logits
         self.enable_kernel = enable_kernel
         self.enable_comm_overlap = enable_comm_overlap
         self.expert_parallel = MOE_MANAGER.get_parallel()
+        self.router_loss = router_loss
+        self.router_norm = router_norm
 
         # moe router
         noisy_func = get_noise_generator(router_noisy_policy, num_experts)

@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
         tokens = inputs.reshape(-1, self.hidden_size)
 
         # the data type of the inputs in the gating should be fp32
-        fp32_input = tokens.to(torch.float)
-        fp32_weight = self.gate_weight.to(torch.float)
-        gate_output = F.linear(fp32_input, fp32_weight)
+        gate_logits = F.linear(tokens, self.gate_weight)
+        gate_output = gate_logits.to(torch.float)
 
         # update expert load
         if self.enable_load_balance == True:

@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
         # the result from the router
         used_capacity, *route_result_list = self.router(
-            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+            inputs=gate_output,
+            use_kernel=self.enable_kernel,
+            ep_group=self.ep_group,
+            use_loss=self.router_loss,
+            use_norm=self.router_norm,
+        )
 
         # dispatch_data: (num_experts, capacity, hidden_size)
         if self.enable_kernel:

@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
-            raise NotImplementedError(
-                "This kind of communication has not been implemented yet.\n"
-                "Please use Experts build function."
-            )
+            raise NotImplementedError(
+                "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
+            )
 
         if self.enable_kernel:
             expert_output = expert_output.reshape(-1, self.hidden_size)

@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
             ans = torch.matmul(combine_weights, expert_output)
             ans = ans.reshape(inputs.shape)
-        return ans
+
+        if self.return_gate_logits:
+            return ans, gate_logits
+        else:
+            return ans
 
     def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
         expert_in = expert_in.unsqueeze(0)

@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
         return expert_out
 
     def _ep_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         Expert Parallel

@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_input = HierarchicalAllToAll.apply(
+                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_output = HierarchicalAllToAll.apply(
+                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]

@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
             NUM_CHUNK = 4
             NUM_STAGES = 4
 
-            assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
+            assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
             chunk_size = dispatch_data.shape[1] // NUM_CHUNK
             input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
             dispatch_data = dispatch_data.reshape(*input_shape)

@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
             for i in range(NUM_CHUNK + NUM_STAGES - 1):
                 if expert_out is not None:
                     expert_out.handle.wait()
-                    output[:, :, offset:offset + chunk_size, :] = expert_out.data
+                    output[:, :, offset : offset + chunk_size, :] = expert_out.data
                     offset += chunk_size
                     expert_out = None
 
                 # all2all last output
                 if _expert_out is not None:
-                    expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
+                    expert_out = Capsule(
+                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
+                    )
                     _expert_out = None
 
                 # all2all next input

@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
         return output
 
     def _tp_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         without overlap:

@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
             NUM_CHUNK = 4
             NUM_STAGES = 4
 
-            assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
-                "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+            assert (
+                dispatch_data.shape[0] % NUM_CHUNK == 0
+            ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
             chunk_size = dispatch_data.shape[0] // NUM_CHUNK
             chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
             output = torch.empty_like(dispatch_data)
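With `return_gate_logits=True`, `SparseMLP.forward` now returns a tuple instead of a single tensor, so callers have to branch on the return type. A hedged sketch of how the new contract might be consumed (the stand-in module and shapes below are hypothetical, not code from the commit):

    import torch

    def fake_sparse_mlp(x: torch.Tensor, return_gate_logits: bool = False):
        # Stand-in for SparseMLP.forward after this commit; shapes are arbitrary.
        ans = x * 2.0
        gate_logits = torch.randn(x.shape[0], 4)  # raw router logits, pre-softmax
        return (ans, gate_logits) if return_gate_logits else ans

    x = torch.randn(6, 16)
    out = fake_sparse_mlp(x, return_gate_logits=True)
    if isinstance(out, tuple):
        ans, gate_logits = out  # gate_logits can feed an external load-balancing loss
    else:
        ans = out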
colossalai/moe/routers.py

@@ -150,7 +150,14 @@ class Top1Router(MoeRouter):
             high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
         ).rsample
 
-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        use_kernel: bool = False,
+        ep_group: Optional[ProcessGroup] = None,
+        use_loss: bool = False,
+        use_norm: bool = False,
+    ) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

@@ -207,7 +214,7 @@ class Top1Router(MoeRouter):
         weight = mask * probs.type_as(inputs)
         combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
         sec_mask = combine_weights.bool()
-        return used_capacity, combine_weights, sec_mask
+        return used_capacity, combine_weights, sec_mask, probs
 
 
 class Top2Router(MoeRouter):

@@ -240,7 +247,14 @@ class Top2Router(MoeRouter):
             drop_tks=drop_tks,
         )
 
-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        use_kernel: bool = False,
+        ep_group: Optional[ProcessGroup] = None,
+        use_norm: bool = False,
+        use_loss: bool = True,
+    ) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

@@ -257,6 +271,10 @@ class Top2Router(MoeRouter):
         assert inputs.dtype == torch.float
         probs = F.softmax(inputs, dim=-1)
+        if use_norm:
+            routing_weights, _ = torch.topk(probs, 2, dim=-1)
+            probs = probs / routing_weights.sum(dim=-1, keepdim=True)
+
         num_experts = probs.size(-1)
         capacity = self.get_capacity(inputs.shape)

@@ -270,10 +288,11 @@ class Top2Router(MoeRouter):
         cmask = cmask.float() / 2.0  # div 2 to normalize it to 1
 
         # calculate loss
-        expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
-        self.set_aux_loss(probs, expert_indices, num_experts)
-        self.set_z_loss(inputs)
-        self.pop_router_loss()
+        if use_loss:
+            expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
+            self.set_aux_loss(probs, expert_indices, num_experts)
+            self.set_z_loss(inputs)
+            self.pop_router_loss()
 
         if not self.training and not self.drop_tks and ep_group is not None:
             max_num = torch.max(torch.sum(cmask, dim=0))
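The new `use_norm` branch re-normalizes the softmax probabilities by the sum of the two largest values, so the two selected experts receive weights that sum to 1, in the style of Mixtral's top-2 routing. A small numeric sketch of that arithmetic (the logits are made up):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    probs = F.softmax(logits, dim=-1)

    # mirrors the `use_norm` branch added to Top2Router.forward
    routing_weights, _ = torch.topk(probs, 2, dim=-1)
    probs = probs / routing_weights.sum(dim=-1, keepdim=True)

    top2, _ = torch.topk(probs, 2, dim=-1)
    assert torch.allclose(top2.sum(dim=-1), torch.ones(1))  # the chosen pair now sums to 1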
colossalai/moe/utils.py

@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
         return torch.nn.GELU()
     elif act == "swiglu":
         return SwiGLU
+    elif act == "silu":
+        return torch.nn.SiLU()
     else:
         raise NotImplementedError("Unsupported activation function")
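The new "silu" branch simply maps the name to `torch.nn.SiLU()`. A quick check of what that activation computes (example values are arbitrary):

    import torch

    act = torch.nn.SiLU()  # what get_activation("silu") now returns
    x = torch.tensor([-1.0, 0.0, 1.0])
    assert torch.allclose(act(x), x * torch.sigmoid(x))  # silu(x) = x * sigmoid(x)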
colossalai/zero/low_level/low_level_optim.py

@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # because they have different parallel strategy
         # so we need to store them separately in param_groups
         # instead of working_groups
-        moe_params = list()
+        self.working_moe_params = list()
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training

@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             if self.moe_extra_dp_pg is None:
                 # skip moe param
                 if is_moe_tensor(param):
-                    moe_params.append(param)
+                    self.working_moe_params.append(param)
                     continue
             group_params.append(param)

@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # managed by this data parallel rank
             param_group["params"] = master_param_current_rank
 
-        # if there are moe params, store in additional group in optim
-        if len(moe_params) > 0:
+        # if there are moe params, store in additional group in optim
+        if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
+            # create fp32 master param
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
-            param_group["params"] = moe_params
+            self.master_moe_params = []
+            for param in self.working_moe_params:
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
+            # create mapping from master to working for optimizer io
+            self.moe_master_to_working_map = {}
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
+            # add to optim
+            param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)
 
         # initialize communication stream for

@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # update the params in the optimizer
             self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
 
-        # update param for moe ep
+        # move grad to master param and compute norm
+        if len(self.working_moe_params) > 0:
+            moe_grads = []
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                if master_moe_param.grad is not None:
+                    raise RuntimeError("Moe param should not have grad here")
+                grad = working_moe_param.grad
+                # no need to copy fp32 grad if master_weights is False
+                if self._master_weights:
+                    grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
+                master_moe_param.grad = grad
+                working_moe_param.grad = None
+                moe_grads.append(grad)
+                grad_partition_groups.append(grad)
+            norm_group = self._compute_grad_norm(gradients=moe_grads)
+            norm_groups.append(norm_group)
+            self.optim.param_groups[-1]["params"] = self.master_moe_params
+            del moe_grads
 
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)
 
-        # TODO: we should store master param for ep
-        if len(self.param_groups) > len(self._working_param_groups):
-            for param in self.param_groups[-1]["params"]:
-                param.data = param.data.to(torch.float32)
-                param.grad = param.grad.to(torch.float32)
-
         # update the parameters
         self.optim.step()
 
-        # release the moe gradm
-        if len(self.param_groups) > len(self._working_param_groups):
-            for param in self.param_groups[-1]["params"]:
-                param.grad = None
-                param.data = param.data.to(self._dtype)
+        # release moe grad
+        if len(self.working_moe_params) > 0:
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                master_moe_param.grad = None
+                working_moe_param.data = (
+                    master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
+                )
 
         # release the grad
         grad_partition_groups = []

@@ -640,6 +666,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
+    def sync_moe_master_param(self):
+        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+            master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
+
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
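The optimizer changes boil down to keeping fp32 master copies of the MoE parameters, moving the working grads onto them before `step()`, and copying the updated master data back afterwards. A hedged sketch of that bookkeeping with plain tensors (no optimizer and no distributed state; the names mirror the diff, everything else is illustrative):

    import torch

    # working (low-precision) MoE params, as the model holds them
    working_moe_params = [torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True)]

    # 1) fp32 master copies plus an id-based master->working map (as in __init__)
    master_moe_params = [p.clone().to(torch.float32).detach() for p in working_moe_params]
    moe_master_to_working_map = {id(m): w for m, w in zip(master_moe_params, working_moe_params)}

    # 2) before step(): move each working grad onto its master copy
    working_moe_params[0].grad = torch.randn_like(working_moe_params[0])
    for master, working in zip(master_moe_params, working_moe_params):
        master.grad = working.grad.to(master.dtype).to(master.device)
        working.grad = None

    # ... an optimizer step over master_moe_params would happen here ...

    # 3) after step(): release grads and write master data back into the working params
    for master, working in zip(master_moe_params, working_moe_params):
        master.grad = None
        working.data = master.data.to(working.device).to(working.dtype).detach()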
tests/test_moe/moe_utils.py

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
 from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_moe_epsize_param_dict
 from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
 
 
 def delete_moe_info(model):
     for _, param in model.named_parameters():
         if hasattr(param, "moe_info"):
             delattr(param, "moe_info")
 
 
 class MoeModel(nn.Module):

@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
     for i in range(world_size - 1):
         a = tensor_list[i]
         b = tensor_list[i + 1]
-        assert not torch.allclose(a, b), \
-            (f"expected tensors on rank {i} and {i + 1} not to be equal "
-             f"but they are, {a} vs {b}")
+        assert not torch.allclose(a, b), (
+            f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
+        )
 
 
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+    model.train()
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    if isinstance(model, LowLevelZeroModel):
+        optimizer.backward(loss)
+    else:
+        loss.backward()
+    return y
+
+
+def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+    """Sync the parameters of tp model from ep model
+
+    Args:
+        local_model (MoeModule)
+        ep_model (MoeModule)
+    """
+    for (local_name, local_param), (ep_name, ep_param) in zip(
+        local_model.named_parameters(), ep_model.named_parameters()
+    ):
+        assert local_name in ep_name, print(f"{local_name} != {ep_name}")
+        if "experts" not in local_name:
+            if assert_grad_flag:
+                assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
+                assert torch.allclose(local_param.grad, ep_param.grad)
+            else:
+                local_param.data.copy_(ep_param.data)
+            continue
+
+        # gather param from ep model
+        param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+        dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+        all_param = torch.cat(param_list, dim=0)
+        if assert_grad_flag:
+            grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+            dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+            all_grad = torch.cat(grad_list, dim=0)
+
+        if assert_grad_flag:
+            assert torch.allclose(local_param, all_param)
+            assert torch.allclose(local_param.grad, all_grad)
+        else:
+            local_param.data.copy_(all_param.data)
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+    rtol = None
+    atol = None
+    if dtype is torch.float16:
+        rtol = 5e-2
+        atol = 5e-4
+    elif dtype is torch.bfloat16:
+        rtol = 4e-3
+        atol = 4e-3
+
+    a = a.detach().to(dtype)
+    b = b.detach().to(dtype).to(a.device)
+
+    assert_close(a, b, rtol=rtol, atol=atol)
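`loose_close` wraps `torch.testing.assert_close` with dtype-dependent tolerances (rtol/atol of 5e-2/5e-4 for fp16 and 4e-3/4e-3 for bf16). A tiny usage sketch with made-up values:

    import torch
    from torch.testing import assert_close

    a = torch.tensor([1.000, 2.000]).bfloat16()
    b = torch.tensor([1.002, 1.999]).bfloat16()
    # same tolerances the helper applies for bf16 inputs
    assert_close(a, b.to(a.device), rtol=4e-3, atol=4e-3)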
tests/test_moe/test_moe_zero_fwd_bwd.py

@@ -4,102 +4,75 @@ import torch
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
-
-
-def split_ddp_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        if criterion:
-            y = model(data)
-            loss = criterion(y, label)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-    if isinstance(model, LowLevelZeroModel):
-        optimizer.backward(loss)
-    else:
-        loss.backward()
-    return y
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
 
 
-def run_zero_test(local_rank, world_size, stage=1):
+def run_zero_test(local_rank, stage=1):
     criterion = torch.nn.CrossEntropyLoss()
 
-    zero_model = MoeModel()
-    optimizer = torch.optim.Adam(zero_model.parameters())
-    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
-    booster = Booster(plugin=plugin)
-    zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
-
-    torch_model = MoeModel()
-    for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
-        torch_param.data.copy_(zero_param.data)
-    torch_model = torch_model.cuda()
-    grad_handler = MoeGradientHandler(torch_model)
-
-    # assert zero model
-    for (torch_name, torch_param), (zero_name, zero_param) in zip(
-        torch_model.named_parameters(), zero_model.module.named_parameters()
-    ):
-        assert zero_name == torch_name
-        assert torch.allclose(zero_param.data, torch_param.data)
-
-    data = torch.randn(16, 4).cuda()
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel="EP")
+    moe_model = MoeModel().bfloat16()
+    moe_optimizer = torch.optim.Adam(moe_model.parameters())
+    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    moe_booster = Booster(plugin=moe_plugin)
+    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel=None)
+    zero_model = MoeModel().bfloat16()
+    delete_moe_info(zero_model)
+    zero_optimizer = torch.optim.Adam(zero_model.parameters())
+    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    zero_booster = Booster(plugin=zero_plugin)
+    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
+    sync_local_from_ep(zero_model, moe_model)
+
+    data = torch.randn(16, 4).bfloat16().cuda()
     label = torch.randint(0, 4, (16,)).cuda()
 
-    torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
-    zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
-    assert torch.allclose(torch_out, zero_out)
-    grad_handler.handle_gradient()
+    zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+    moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
+    assert torch.allclose(zero_out, moe_out)
 
-    for (zero_name, zero_param), (torch_name, torch_param) in zip(
-        zero_model.module.named_parameters(), torch_model.named_parameters()
+    for (moe_name, moe_param), (zero_name, zero_param) in zip(
+        moe_model.module.named_parameters(), zero_model.module.named_parameters()
     ):
-        assert zero_name == torch_name
-        zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
-        if hasattr(zero_param, "moe_info"):
-            assert len(zero_grad_list) == 0
-            assert torch.allclose(zero_param.grad, torch_param.grad)
+        assert moe_name == zero_name
+        moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
+        zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
+        if hasattr(moe_param, "moe_info"):
+            assert len(moe_grad_list) == 0
+            if stage == 1:
+                zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
+            else:
+                zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
+            assert torch.allclose(
+                moe_param.grad, zero_grad, atol=1e-5
+            ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
         else:
-            assert len(zero_grad_list) > 0
-            torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
-            if stage == 2:
-                torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
-            assert len(zero_grad_list) == len(torch_grad_list)
-            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
-                assert torch.allclose(zero_grad, torch_grad)
+            assert len(moe_grad_list) > 0
+            assert len(moe_grad_list) == len(zero_grad_list)
+            for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
+                assert torch.allclose(moe_grad, zero_grad)
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    MOE_MANAGER.setup(parallel="EP")
-    run_zero_test(rank, world_size, stage=1)
-    run_zero_test(rank, world_size, stage=2)
+    seed_all(42 + rank)
+    run_zero_test(rank, stage=stage)
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("stage", [1, 2])
 @rerun_if_address_is_in_use()
-def test_moe_zero_model(world_size):
-    spawn(run_dist, world_size)
+def test_moe_zero_model(world_size, stage):
+    spawn(run_dist, world_size, stage=stage)
 
 
 if __name__ == "__main__":
-    test_moe_zero_model(world_size=2)
+    test_moe_zero_model(world_size=2, stage=1)
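The stage-dependent indexing of `zero_grad_list` appears to reflect how the low-level ZeRO gradient store partitions gradients: under stage 1 the test picks the partition belonging to `local_rank`, while under stage 2 only the locally owned partition is kept, so index 0 is used. An illustrative sketch with plain tensors (this is an interpretation, not the optimizer's real data structures):

    import torch

    full_grad = torch.arange(8.0)
    world_size, local_rank = 2, 1
    partitions = list(full_grad.chunk(world_size))

    # stage 1: the store exposes every rank's partition; pick this rank's slice
    grad_stage1 = partitions[local_rank]

    # stage 2: only the partition owned by this rank remains, so it sits at index 0
    local_store_stage2 = partitions[local_rank : local_rank + 1]
    grad_stage2 = local_store_stage2[0]

    assert torch.equal(grad_stage1, grad_stage2)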
tests/test_moe/test_moe_zero_optim.py

@@ -4,89 +4,80 @@ import torch
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
 
 
-def split_ddp_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        if criterion:
-            y = model(data)
-            loss = criterion(y, label)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-    if isinstance(model, LowLevelZeroModel):
-        optimizer.backward(loss)
-    else:
-        loss.backward()
-    return y
-
-
-def run_zero_optim_test(local_rank, world_size, stage=1):
+def run_zero_test(local_rank, stage=1):
     criterion = torch.nn.CrossEntropyLoss()
 
-    zero_model = MoeModel()
-    zero_optimizer = torch.optim.Adam(zero_model.parameters())
-    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
-    booster = Booster(plugin=plugin)
-    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
-
-    torch_model = MoeModel()
-    for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
-        torch_param.data.copy_(zero_param.data)
-    torch_optimizer = torch.optim.Adam(torch_model.parameters())
-    torch_model = torch_model.cuda()
-    grad_handler = MoeGradientHandler(torch_model)
-
-    for _ in range(2):
-        data = torch.randn(16, 4).cuda() / (local_rank + 1)
-        label = torch.randint(0, 4, (16,)).cuda()
-        run_fwd_bwd(torch_model, data, label, criterion, None)
-        run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
-        grad_handler.handle_gradient()
-        torch_optimizer.step()
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel="EP")
+    moe_model = MoeModel().bfloat16()
+    moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
+    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    moe_booster = Booster(plugin=moe_plugin)
+    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel=None)
+    zero_model = MoeModel().bfloat16()
+    delete_moe_info(zero_model)
+    sync_local_from_ep(zero_model, moe_model)
+    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
+    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    zero_booster = Booster(plugin=zero_plugin)
+    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
+
+    for (moe_name, moe_param), (zero_name, zero_param) in zip(
+        moe_model.named_parameters(), zero_model.named_parameters()
+    ):
+        if ".experts." in moe_name:
+            continue
+        assert moe_name == zero_name
+        assert torch.allclose(
+            moe_param.data, zero_param.data
+        ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
+
+    for _ in range(1):
+        data = torch.randn(2, 4).bfloat16().cuda()
+        label = torch.randint(0, 4, (2,)).cuda()
+
+        moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
+        zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+        assert torch.allclose(zero_out, moe_out)
+        moe_optimizer.step()
         zero_optimizer.step()
 
-        for (torch_name, torch_param), (zero_name, zero_param) in zip(
-            torch_model.named_parameters(), zero_model.named_parameters()
+        for (moe_name, moe_param), (zero_name, zero_param) in zip(
+            moe_model.named_parameters(), zero_model.named_parameters()
         ):
-            assert torch.allclose(
-                torch_param.data, zero_param.data
-            ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
+            assert moe_name == zero_name
+            if is_moe_tensor(moe_param):
+                param_size = moe_param.shape[0]
+                zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
+            loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
 
-        torch_optimizer.zero_grad()
+        moe_optimizer.zero_grad()
         zero_optimizer.zero_grad()
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    MOE_MANAGER.setup(parallel="EP")
-    run_zero_optim_test(rank, world_size, stage=1)
-    run_zero_optim_test(rank, world_size, stage=2)
+    seed_all(42 + rank)
+    run_zero_test(rank, stage=stage)
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("stage", [1, 2])
 @rerun_if_address_is_in_use()
-def test_moe_zero_optim(world_size):
-    spawn(run_dist, world_size)
+def test_moe_zero_optim(world_size, stage):
    spawn(run_dist, world_size, stage=stage)
 
 
 if __name__ == "__main__":
-    test_moe_zero_optim(world_size=2)
+    test_moe_zero_optim(world_size=2, stage=1)
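When a parameter is expert-parallel, each EP rank holds a contiguous block of experts along dim 0, which is why the test slices the full (non-parallel) copy with `local_rank * param_size : (local_rank + 1) * param_size` before comparing. A small sketch of that slicing (shapes and values are made up):

    import torch

    num_experts, ep_size, local_rank = 8, 2, 1
    full_param = torch.randn(num_experts, 4, 4)   # parameter of the non-parallel model
    moe_param = full_param[4:8]                   # block this EP rank would own
    param_size = moe_param.shape[0]

    sliced = full_param[local_rank * param_size : (local_rank + 1) * param_size]
    assert torch.equal(sliced, moe_param)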