ColossalAI · Commit efef43b5 (unverified)

Merge pull request #5372 from hpcaitech/exp/mixtral

Authored Feb 08, 2024 by Frank Lee; committed by GitHub on Feb 08, 2024.
Parents: 4c03347f, 06db94fb

Changes: 33 in total (the diff is paginated). This page shows 13 changed files with 444 additions and 266 deletions (+444, -266).
Files in this page of the diff:

- colossalai/moe/_operation.py (+66, -1)
- colossalai/moe/checkpoint.py (+39, -28)
- colossalai/moe/experts.py (+5, -1)
- colossalai/moe/layers.py (+40, -32)
- colossalai/moe/routers.py (+36, -11)
- colossalai/moe/utils.py (+2, -0)
- colossalai/tensor/moe_tensor/moe_info.py (+2, -0)
- colossalai/zero/low_level/low_level_optim.py (+47, -16)
- tests/test_moe/moe_utils.py (+80, -3)
- tests/test_moe/test_moe_checkpoint.py (+5, -22)
- tests/test_moe/test_moe_router.py (+19, -13)
- tests/test_moe/test_moe_zero_fwd_bwd.py (+48, -75)
- tests/test_moe/test_moe_zero_optim.py (+55, -64)
colossalai/moe/_operation.py

```diff
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import torch
 import torch.distributed as dist

@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
         if ctx.ep_size != 1:
             grad = grad / ctx.ep_size
         return grad, None
+
+
+def _all_to_all(
+    inputs: torch.Tensor,
+    input_split_sizes: Optional[List[int]] = None,
+    output_split_sizes: Optional[List[int]] = None,
+    group=None,
+    async_op: bool = False,
+):
+    """
+    Returns:
+        outputs: Tensor
+        handle: Optional[Work], if overlap is True
+    """
+    outputs_shape = list(inputs.shape)
+    if output_split_sizes is not None:
+        outputs_shape[0] = sum(output_split_sizes)
+    outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
+    inputs = inputs.contiguous()
+    outputs = outputs.contiguous()
+    handle = dist.all_to_all_single(
+        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+    )
+    return outputs, handle
+
+
+class AllToAllUneven(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        inputs,
+        input_split_sizes=None,
+        output_split_sizes=None,
+        group=None,
+        overlap: bool = False,
+    ):
+        """
+        Returns:
+            outputs: Tensor
+            handle: Optional[Work], if overlap is True
+        """
+        ctx.input_split_sizes = input_split_sizes
+        ctx.output_split_sizes = output_split_sizes
+        ctx.group = group
+        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs):
+        return (
+            _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def all_to_all_uneven(
+    inputs: torch.Tensor,
+    input_split_sizes: Optional[List[int]] = None,
+    output_split_sizes: Optional[List[int]] = None,
+    group=None,
+    overlap: bool = False,
+):
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
```
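The backward of `AllToAllUneven` swaps the input and output split sizes, since gradients flow in the reverse direction of the forward exchange. Below is a minimal sketch of how such an uneven all-to-all could be exercised on two ranks; it assumes a 2-rank NCCL process group is already initialized, and the split sizes are hypothetical, chosen only for illustration.

```python
# Hypothetical two-rank illustration of uneven all-to-all (not part of the commit).
# Rank 0 keeps 1 row and sends 3 rows to rank 1; rank 1 sends 2 rows to each rank.
import torch
import torch.distributed as dist

def demo(rank: int):
    inputs = torch.randn(4, 8, device=f"cuda:{rank}")  # 4 rows of hidden size 8
    if rank == 0:
        input_split_sizes = [1, 3]   # rows sent to ranks 0 and 1
        output_split_sizes = [1, 2]  # rows received from ranks 0 and 1
    else:
        input_split_sizes = [2, 2]
        output_split_sizes = [3, 2]
    outputs = torch.empty(sum(output_split_sizes), 8, device=inputs.device)
    dist.all_to_all_single(outputs, inputs, output_split_sizes, input_split_sizes)
    return outputs  # rank 0 ends up with 3 rows, rank 1 with 5
```

This is what makes the routine suitable for MoE token dispatch: each rank can send a different number of tokens to each expert rank, rather than padding to a fixed capacity.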
colossalai/moe/checkpoint.py

```diff
@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
         """
+        torch.cuda.empty_cache()
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
                 f"index located at {save_index_file}."
             )
         dist.barrier()
+        torch.cuda.empty_cache()

     # ========================================================
     # Abstract methods for optimizer loading/saving implementation
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

         def _get_param_id_from_optimizer_param(
-            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
         ):
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
             return optimizer.param_info["param2id"][id(working_param)]
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         master_to_working_map = optimizer.get_master_to_working_map()
         for pg in optimizer.optim.param_groups:
             for param in pg["params"]:
-                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
                 id_map[param_id] = param

         # Read checkpoint index file.
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
                 new_pg = copy.deepcopy(saved_pg)
                 new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
                 updated_groups.append(new_pg)
-            # ep extra group
-            if MOE_MANAGER.parallel == "EP":
+            # ep param group
+            if len(optimizer.optim.param_groups) > len(saved_groups):
                 new_pg = copy.deepcopy(saved_pg)
-                new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
-                # Only keep the parameters kept by current pipeline stage.
-                for param in new_pg["params"]:
-                    param.data = param.data.to(torch.float32)
+                new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
                 updated_groups.append(new_pg)
         optimizer.optim.__dict__.update({"param_groups": updated_groups})
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -400,26 +400,33 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             file_path = os.path.join(ckpt_root_path, filename)
             state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
-            load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
-            loaded_file.add(filename)
-
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            device = param.device
+
+            # Then shard the loaded optimizer states if using tp/zero.
+            for pid, state in list(state_dict.items()):
+                if pid in id_map:
+                    param = id_map[pid]
                     if master_to_working_map is not None and id(param) in master_to_working_map:
                         working_param = master_to_working_map[id(param)]
+                    elif (
+                        hasattr(optimizer, "moe_master_to_working_map")
+                        and id(param) in optimizer.moe_master_to_working_map
+                    ):
+                        working_param = optimizer.moe_master_to_working_map[id(param)]
                     else:
                         working_param = param
                     original_shape = optimizer.param_info["param2shape"][id(working_param)]
                     sharded_state = self.pre_load_optim(
                         state,
-                        param,
+                        working_param,
                         current_shape=working_param.shape,
                         original_shape=original_shape,
-                        device=device,
+                        device="cpu",
                         inplace=True,
                     )
-                    optimizer.optim.state[param] = sharded_state
+                    state_dict[pid] = sharded_state
+
+            load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+            loaded_file.add(filename)

         sharded_optimizer_loading_epilogue(optimizer.optim)
         if self.verbose and self.coordinator.is_master():
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             prefix (str): Prefix of file to save
             size_per_shard (int): Max file size of each file shard that store state tensors
         """
+        torch.cuda.empty_cache()
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
                 f"You can find where each parameters has been saved in the "
                 f"index located at {final_index_file_path}."
            )
+        torch.cuda.empty_cache()

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
```
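A pattern that recurs throughout this file: each optimizer parameter is resolved to its working counterpart through a fixed chain of lookups, first the regular ZeRO master-to-working map, then the new MoE-specific map, and finally an identity fallback. A minimal sketch of that resolution order, with hypothetical maps:

```python
# Minimal sketch of the lookup order used above (the maps here are hypothetical).
def resolve_working_param(param, master_to_working_map=None, moe_master_to_working_map=None):
    if master_to_working_map is not None and id(param) in master_to_working_map:
        return master_to_working_map[id(param)]      # regular ZeRO master param
    if moe_master_to_working_map is not None and id(param) in moe_master_to_working_map:
        return moe_master_to_working_map[id(param)]  # MoE master param (new in this PR)
    return param                                     # already a working param
```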
colossalai/moe/experts.py

```diff
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
             self.ep_size = 1

         if gated:
-            self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
+            self.wi_gate = nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
+                )
+            )
             self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
         else:
             self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
```
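The change sizes the gate projection at `2 * intermediate_size` only for SwiGLU, which fuses the gate and up projections into a single matrix; other gated activations keep a plain `intermediate_size` gate. A rough sketch of why SwiGLU needs the doubled width (the function and names below are illustrative, not the module's actual forward):

```python
# Illustrative only: SwiGLU fuses the gate and up projections into one weight,
# so the fused projection maps hidden_size -> 2 * intermediate_size.
import torch
import torch.nn.functional as F

def swiglu_ffn(x, wi_gate, wo):
    # wi_gate: (hidden_size, 2 * intermediate_size), wo: (intermediate_size, hidden_size)
    gate_up = x @ wi_gate
    gate, up = gate_up.chunk(2, dim=-1)  # split the fused projection
    return (F.silu(gate) * up) @ wo      # SwiGLU: silu(x W_g) * (x W_u)
```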
colossalai/moe/layers.py

```diff
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         router_top_k: int = 1,
+        router_loss: bool = True,
+        router_norm: bool = False,
         router_capacity_factor_train: float = 1.25,
         router_capacity_factor_eval: float = 2.0,
         router_min_capacity: int = 4,
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
         enable_kernel: bool = False,
         enable_comm_overlap: bool = False,
         enable_hierarchical_comm: bool = False,
+        return_gate_logits: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_experts = num_experts
         self.gated = mlp_gated
+        self.return_gate_logits = return_gate_logits
         self.enable_kernel = enable_kernel
         self.enable_comm_overlap = enable_comm_overlap
         self.expert_parallel = MOE_MANAGER.get_parallel()
+        self.router_loss = router_loss
+        self.router_norm = router_norm

         # moe router
         noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
         tokens = inputs.reshape(-1, self.hidden_size)

         # the data type of the inputs in the gating should be fp32
-        fp32_input = tokens.to(torch.float)
-        fp32_weight = self.gate_weight.to(torch.float)
-        gate_output = F.linear(fp32_input, fp32_weight)
+        gate_logits = F.linear(tokens, self.gate_weight)
+        gate_output = gate_logits.to(torch.float)

         # update expert load
         if self.enable_load_balance == True:
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
         # the result from the router
-        used_capacity, *route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+        used_capacity, *route_result_list = self.router(
+            inputs=gate_output,
+            use_kernel=self.enable_kernel,
+            ep_group=self.ep_group,
+            use_loss=self.router_loss,
+            use_norm=self.router_norm,
+        )

         # dispatch_data: (num_experts, capacity, hidden_size)
         if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
-            raise NotImplementedError("This kind of communication has not been implemented yet.\n"
-                                      "Please use Experts build function.")
+            raise NotImplementedError(
+                "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
+            )

         if self.enable_kernel:
             expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,6 +207,10 @@ class SparseMLP(nn.Module):
             ans = torch.matmul(combine_weights, expert_output)

         ans = ans.reshape(inputs.shape)
-        return ans
+
+        if self.return_gate_logits:
+            return ans, gate_logits
+        else:
+            return ans

     def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
         return expert_out

     def _ep_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         Expert Parallel
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_input = HierarchicalAllToAll.apply(
+                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_output = HierarchicalAllToAll.apply(
+                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
             NUM_CHUNK = 4
             NUM_STAGES = 4

-            assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
+            assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
             chunk_size = dispatch_data.shape[1] // NUM_CHUNK
             input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
             dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
             for i in range(NUM_CHUNK + NUM_STAGES - 1):
                 if expert_out is not None:
                     expert_out.handle.wait()
-                    output[:, :, offset:offset + chunk_size, :] = expert_out.data
+                    output[:, :, offset : offset + chunk_size, :] = expert_out.data
                     offset += chunk_size
                     expert_out = None

                 # all2all last output
                 if _expert_out is not None:
-                    expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
+                    expert_out = Capsule(
+                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
+                    )
                     _expert_out = None

                 # all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
         return output

     def _tp_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         without overlap:
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
         NUM_CHUNK = 4
         NUM_STAGES = 4

-        assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
-            "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+        assert (
+            dispatch_data.shape[0] % NUM_CHUNK == 0
+        ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
         chunk_size = dispatch_data.shape[0] // NUM_CHUNK
         chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
         output = torch.empty_like(dispatch_data)
```
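In the overlapped paths above, dispatch data is split into `NUM_CHUNK` chunks and fed through a small software pipeline (`NUM_CHUNK + NUM_STAGES - 1` iterations) so that the all-to-all for one chunk overlaps with expert compute on another; each in-flight tensor travels with its communication handle, which is waited on before the result is written back at the right offset. A simplified, hypothetical skeleton of that pattern, showing only the dispatch direction:

```python
# Hypothetical skeleton of the chunked overlap (simplified; not the module's actual code).
def pipelined_ep(chunks, all_to_all_async, experts):
    outputs, inflight = [], None
    for i in range(len(chunks) + 1):
        # kick off the exchange for chunk i so it proceeds in the background
        nxt = all_to_all_async(chunks[i]) if i < len(chunks) else None
        if inflight is not None:
            data, handle = inflight
            handle.wait()                  # chunk i-1 has arrived
            outputs.append(experts(data))  # compute overlaps with chunk i's exchange
        inflight = nxt
    return outputs
```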
colossalai/moe/routers.py

```diff
@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
         self._z_loss = None
         self.use_kernel = use_kernel

-    def get_capacity(self, logits_shape):
+    def get_capacity(self, num_tokens, num_experts, ep_group=None):
+        if ep_group is not None:
+            num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
+            dist.all_reduce(num_tokens_tensor, group=ep_group)
+            num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
         capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
-        capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
         capacity += capacity % 2
         capacity = max(capacity, self.min_capacity)
         assert capacity > 0
@@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
             high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
         ).rsample

-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        use_kernel: bool = False,
+        ep_group: Optional[ProcessGroup] = None,
+        use_loss: bool = False,
+        use_norm: bool = False,
+    ) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
         assert inputs.dtype == torch.float
         probs = F.softmax(inputs, dim=-1)
         num_experts = probs.size(-1)
-        capacity = self.get_capacity(inputs.shape)
+        num_tokens = inputs.size(0)
+        capacity = self.get_capacity(num_tokens, num_experts, ep_group)

         top1_idx = torch.argmax(inputs, dim=-1)
         mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
@@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
             weight = mask * probs.type_as(inputs)
             combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
             sec_mask = combine_weights.bool()
-            return used_capacity, combine_weights, sec_mask
+            return used_capacity, combine_weights, sec_mask, probs


 class Top2Router(MoeRouter):
@@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
             drop_tks=drop_tks,
         )

-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        use_kernel: bool = False,
+        ep_group: Optional[ProcessGroup] = None,
+        use_norm: bool = False,
+        use_loss: bool = True,
+    ) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -257,8 +276,13 @@ class Top2Router(MoeRouter):
         assert inputs.dtype == torch.float

         probs = F.softmax(inputs, dim=-1)
+        if use_norm:
+            routing_weights, _ = torch.topk(probs, 2, dim=-1)
+            probs = probs / routing_weights.sum(dim=-1, keepdim=True)
+
         num_experts = probs.size(-1)
-        capacity = self.get_capacity(inputs.shape)
+        num_tokens = inputs.size(0)
+        capacity = self.get_capacity(num_tokens, num_experts, ep_group)

         top1_idx = torch.argmax(probs, dim=-1)
         mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
@@ -270,6 +294,7 @@ class Top2Router(MoeRouter):
         cmask = cmask.float() / 2.0  # div 2 to normalize it to 1

         # calculate loss
+        if use_loss:
             expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
             self.set_aux_loss(probs, expert_indices, num_experts)
             self.set_z_loss(inputs)
```
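The reworked `get_capacity` makes the capacity formula explicit in terms of token and expert counts, and (new here) averages the token count across the expert-parallel group so every rank computes the same capacity: capacity = floor(k * capacity_factor * num_tokens / num_experts), bumped to the next even number and clamped at `min_capacity`. A quick worked check with illustrative numbers:

```python
# Worked example with illustrative numbers (not taken from the commit):
# k=2 (top-2), capacity_factor=1.25, 64 tokens, 8 experts.
import math

k_value, capacity_factor, num_tokens, num_experts = 2, 1.25, 64, 8
capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)  # floor(20.0) = 20
capacity += capacity % 2     # already even, stays 20
capacity = max(capacity, 4)  # 4 matches SparseMLP's router_min_capacity default
assert capacity == 20
```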
colossalai/moe/utils.py

```diff
@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
         return torch.nn.GELU()
     elif act == "swiglu":
         return SwiGLU
+    elif act == "silu":
+        return torch.nn.SiLU()
     else:
         raise NotImplementedError("Unsupported activation function")
```
colossalai/tensor/moe_tensor/moe_info.py

```diff
@@ -26,3 +26,5 @@ class MoeParallelInfo:
         self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
         self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
         self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
+        self.ep_rank = self.pg.coordinate(self.ep_axis)
+        self.dp_rank = self.pg.coordinate(self.dp_axis)
```
colossalai/zero/low_level/low_level_optim.py

```diff
@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # because they have different parallel strategy
         # so we need to store them separately in param_groups
         # instead of working_groups
-        moe_params = list()
+        self.working_moe_params = list()

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if self.moe_extra_dp_pg is None:
                     # skip moe param
                     if is_moe_tensor(param):
-                        moe_params.append(param)
+                        self.working_moe_params.append(param)
                         continue
                 group_params.append(param)
@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # managed by this data parallel rank
             param_group["params"] = master_param_current_rank

         # if there are moe params, store in additional group in optim
-        if len(moe_params) > 0:
+        if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
+            # create fp32 master param
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
-            param_group["params"] = moe_params
+            self.master_moe_params = []
+            for param in self.working_moe_params:
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
+            # create mapping from master to working for optimizer io
+            self.moe_master_to_working_map = {}
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
+            # add to optim
+            param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)

         # initialize communication stream for
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # update the params in the optimizer
             self.optim.param_groups[group_id]["params"] = real_master_params[group_id]

-        # update param for moe ep
+        # move grad to master param and compute norm
+        if len(self.working_moe_params) > 0:
+            moe_grads = []
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                if master_moe_param.grad is not None:
+                    raise RuntimeError("Moe param should not have grad here")
+                grad = working_moe_param.grad
+                # no need to copy fp32 grad if master_weights is False
+                if self._master_weights:
+                    grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
+                master_moe_param.grad = grad
+                working_moe_param.grad = None
+                moe_grads.append(grad)
+                grad_partition_groups.append(grad)
+            norm_group = self._compute_grad_norm(gradients=moe_grads)
+            norm_groups.append(norm_group)
+            self.optim.param_groups[-1]["params"] = self.master_moe_params
+            del moe_grads
+
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)

-        # TODO: we should store master param for ep
-        if len(self.param_groups) > len(self._working_param_groups):
-            for param in self.param_groups[-1]["params"]:
-                param.data = param.data.to(torch.float32)
-                param.grad = param.grad.to(torch.float32)
-
         # update the parameters
         self.optim.step()

-        # release the moe grad
-        if len(self.param_groups) > len(self._working_param_groups):
-            for param in self.param_groups[-1]["params"]:
-                param.grad = None
-                param.data = param.data.to(self._dtype)
+        # release moe grad
+        if len(self.working_moe_params) > 0:
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                master_moe_param.grad = None
+                working_moe_param.data = (
+                    master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
+                )

         # release the grad
         grad_partition_groups = []
@@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
             else:
                 master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+        if hasattr(self, "master_moe_params"):
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                master_moe_param.copy_(working_moe_param)

     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+        if hasattr(self, "moe_master_to_working_map"):
+            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
         return self._param_store.master_to_working_param
```
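Because MoE parameters now live in their own optimizer group, the optimizer carries a second mapping, `moe_master_to_working_map`, and `get_master_to_working_map` merges it with the regular ZeRO map via dict unpacking. A small illustration of that merge semantics (the ids and values are hypothetical):

```python
# Dict-unpacking merge as used in get_master_to_working_map (ids are hypothetical).
zero_map = {101: "working_dense_param"}
moe_map = {202: "working_moe_param"}
merged = {**zero_map, **moe_map}
assert merged == {101: "working_dense_param", 202: "working_moe_param"}
# a key in moe_map would override the same key in zero_map, but id() keys never collide
```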
tests/test_moe/moe_utils.py

```diff
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close

+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
 from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_moe_epsize_param_dict
 from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
+
+
+def delete_moe_info(model):
+    for _, param in model.named_parameters():
+        if hasattr(param, "moe_info"):
+            delattr(param, "moe_info")


 class MoeModel(nn.Module):
@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
     for i in range(world_size - 1):
         a = tensor_list[i]
         b = tensor_list[i + 1]
-        assert not torch.allclose(a, b), \
-            (f"expected tensors on rank {i} and {i + 1} not to be equal "
-             f"but they are, {a} vs {b}")
+        assert not torch.allclose(a, b), (
+            f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
+        )
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+    model.train()
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    if isinstance(model, LowLevelZeroModel):
+        optimizer.backward(loss)
+    else:
+        loss.backward()
+    return y
+
+
+def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+    """Sync the parameters of tp model from ep model
+
+    Args:
+        local_model (MoeModule)
+        ep_model (MoeModule)
+    """
+    for (local_name, local_param), (ep_name, ep_param) in zip(
+        local_model.named_parameters(), ep_model.named_parameters()
+    ):
+        assert local_name in ep_name, print(f"{local_name} != {ep_name}")
+        if "experts" not in local_name:
+            if assert_grad_flag:
+                assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
+                assert torch.allclose(local_param.grad, ep_param.grad)
+            else:
+                local_param.data.copy_(ep_param.data)
+            continue
+
+        # gather param from ep model
+        param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+        dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+        all_param = torch.cat(param_list, dim=0)
+        if assert_grad_flag:
+            grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+            dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+            all_grad = torch.cat(grad_list, dim=0)
+
+        if assert_grad_flag:
+            assert torch.allclose(local_param, all_param)
+            assert torch.allclose(local_param.grad, all_grad)
+        else:
+            local_param.data.copy_(all_param.data)
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+    rtol = None
+    atol = None
+    if dtype is torch.float16:
+        rtol = 5e-2
+        atol = 5e-4
+    elif dtype is torch.bfloat16:
+        rtol = 4e-3
+        atol = 4e-3
+
+    a = a.detach().to(dtype)
+    b = b.detach().to(dtype).to(a.device)
+
+    assert_close(a, b, rtol=rtol, atol=atol)
```
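`loose_close` wraps `torch.testing.assert_close` with dtype-aware tolerances (fp16: rtol 5e-2 / atol 5e-4; bf16: rtol 4e-3 / atol 4e-3), since bf16 EP and ZeRO runs cannot be expected to match bit-for-bit. A hypothetical usage:

```python
# Hypothetical usage: comparing bf16 results from two parallel strategies.
import torch
from tests.test_moe.moe_utils import loose_close

a = torch.randn(8)
b = a + 1e-3  # small numerical drift, e.g. between EP and pure-ZeRO runs
loose_close(a, b, dtype=torch.bfloat16)  # compares after casting both to bf16
```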
tests/test_moe/test_moe_checkpoint.py

```diff
@@ -12,7 +12,6 @@ import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn

 sys.path.append(
@@ -95,6 +94,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=1,
             zero_stage=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
@@ -103,6 +103,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=dist.get_world_size(),
             zero_stage=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
@@ -111,6 +112,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=2,
             zero_stage=2,
             extra_dp_size=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
@@ -120,6 +122,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=2,
+            ep_size=2,
             zero_stage=1,
             microbatch_size=1,
             custom_policy=OpenMoeForCausalLMPolicy(),
@@ -130,27 +133,6 @@ def get_model(parallel):

 def _test_moe_checkpoint(rank, parallel):
-    if parallel == None:
-        MOE_MANAGER.setup(
-            parallel=None,
-        )
-    elif parallel == "ep":
-        MOE_MANAGER.setup(
-            parallel="EP",
-        )
-    elif parallel == "ep_zero":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=2,
-        )
-    elif parallel == "hybrid":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=1,
-            fixed_ep_size=2,
-            fixed_pp_size=2,
-        )
     model1, booster1, optim1 = get_model(parallel)
     model2, booster2, optim2 = get_model(parallel)
     model3, booster3, optim3 = get_model(parallel)
@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
     _test_moe_checkpoint(rank, parallel)


+@pytest.mark.skip(reason="This is tested in ColossalMOE")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
```
tests/test_moe/test_moe_router.py

```diff
@@ -4,15 +4,21 @@ import torch
 from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter


-@pytest.mark.parametrize(["router", "num_groups"], [
-    (Top1Router(), 1),
-    (Top2Router(), 1),
-    # (TopKRouter(num_selected_experts=3), 4),
-])
-@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
-    (4, 5, 8),
-    (3, 4, 4),
-])
+@pytest.mark.parametrize(
+    ["router", "num_groups"],
+    [
+        (Top1Router(), 1),
+        (Top2Router(), 1),
+        # (TopKRouter(num_selected_experts=3), 4),
+    ],
+)
+@pytest.mark.parametrize(
+    ["batch_size", "seq_len", "num_experts"],
+    [
+        (4, 5, 8),
+        (3, 4, 4),
+    ],
+)
 def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
     x = torch.randn((batch_size * seq_len, num_experts)).cuda()
     if num_groups > 1:
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
     router.train()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)

     router.eval()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
```
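The `[1:3]` slice works because the non-kernel Top-1/Top-2 paths now return a tuple ending in `probs` (as seen in the routers.py change above), so the test keeps only the combine weights and dispatch mask regardless of trailing entries. Assuming the non-kernel 4-tuple form, it is equivalent to:

```python
# Equivalent unpacking, assuming the non-kernel 4-tuple return:
used_capacity, combine_array, dispatch_mask, probs = router(x)
```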
tests/test_moe/test_moe_zero_fwd_bwd.py

```diff
@@ -4,102 +4,75 @@ import torch
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
-
-
-def split_ddp_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        if criterion:
-            y = model(data)
-            loss = criterion(y, label)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-    if isinstance(model, LowLevelZeroModel):
-        optimizer.backward(loss)
-    else:
-        loss.backward()
-    return y
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep


-def run_zero_test(local_rank, world_size, stage=1):
+def run_zero_test(local_rank, stage=1):
     criterion = torch.nn.CrossEntropyLoss()

-    zero_model = MoeModel()
-    optimizer = torch.optim.Adam(zero_model.parameters())
-    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
-    booster = Booster(plugin=plugin)
-    zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
-
-    torch_model = MoeModel()
-    for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
-        torch_param.data.copy_(zero_param.data)
-    torch_model = torch_model.cuda()
-    grad_handler = MoeGradientHandler(torch_model)
-
-    # assert zero model
-    for (torch_name, torch_param), (zero_name, zero_param) in zip(
-        torch_model.named_parameters(), zero_model.module.named_parameters()
-    ):
-        assert zero_name == torch_name
-        assert torch.allclose(zero_param.data, torch_param.data)
-
-    data = torch.randn(16, 4).cuda()
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel="EP")
+    moe_model = MoeModel().bfloat16()
+    moe_optimizer = torch.optim.Adam(moe_model.parameters())
+    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    moe_booster = Booster(plugin=moe_plugin)
+    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel=None)
+    zero_model = MoeModel().bfloat16()
+    delete_moe_info(zero_model)
+    zero_optimizer = torch.optim.Adam(zero_model.parameters())
+    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    zero_booster = Booster(plugin=zero_plugin)
+    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
+    sync_local_from_ep(zero_model, moe_model)
+
+    data = torch.randn(16, 4).bfloat16().cuda()
     label = torch.randint(0, 4, (16,)).cuda()

-    torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
-    zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
-    assert torch.allclose(torch_out, zero_out)
-    grad_handler.handle_gradient()
+    zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+    moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
+    assert torch.allclose(zero_out, moe_out)

-    for (zero_name, zero_param), (torch_name, torch_param) in zip(
-        zero_model.module.named_parameters(), torch_model.named_parameters()
-    ):
-        assert zero_name == torch_name
-        zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
-        if hasattr(zero_param, "moe_info"):
-            assert len(zero_grad_list) == 0
-            assert torch.allclose(zero_param.grad, torch_param.grad)
-        else:
-            assert len(zero_grad_list) > 0
-            torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
-            if stage == 2:
-                torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
-            assert len(zero_grad_list) == len(torch_grad_list)
-            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
-                assert torch.allclose(zero_grad, torch_grad)
+    for (moe_name, moe_param), (zero_name, zero_param) in zip(
+        moe_model.module.named_parameters(), zero_model.module.named_parameters()
+    ):
+        assert moe_name == zero_name
+        moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
+        zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
+        if hasattr(moe_param, "moe_info"):
+            assert len(moe_grad_list) == 0
+            if stage == 1:
+                zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
+            else:
+                zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
+            assert torch.allclose(
+                moe_param.grad, zero_grad, atol=1e-5
+            ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
+        else:
+            assert len(moe_grad_list) > 0
+            assert len(moe_grad_list) == len(zero_grad_list)
+            for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
+                assert torch.allclose(moe_grad, zero_grad)


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    MOE_MANAGER.setup(parallel="EP")
     seed_all(42 + rank)
-    run_zero_test(rank, world_size, stage=1)
-    run_zero_test(rank, world_size, stage=2)
+    run_zero_test(rank, stage=stage)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("stage", [1, 2])
 @rerun_if_address_is_in_use()
-def test_moe_zero_model(world_size):
-    spawn(run_dist, world_size)
+def test_moe_zero_model(world_size, stage):
+    spawn(run_dist, world_size, stage=stage)


 if __name__ == "__main__":
-    test_moe_zero_model(world_size=2)
+    test_moe_zero_model(world_size=2, stage=1)
```
tests/test_moe/test_moe_zero_optim.py

```diff
@@ -4,89 +4,80 @@ import torch
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
-
-
-def split_ddp_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        if criterion:
-            y = model(data)
-            loss = criterion(y, label)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-    if isinstance(model, LowLevelZeroModel):
-        optimizer.backward(loss)
-    else:
-        loss.backward()
-    return y
-
-
-def run_zero_optim_test(local_rank, world_size, stage=1):
-    criterion = torch.nn.CrossEntropyLoss()
-
-    zero_model = MoeModel()
-    zero_optimizer = torch.optim.Adam(zero_model.parameters())
-    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
-    booster = Booster(plugin=plugin)
-    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
-
-    torch_model = MoeModel()
-    for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
-        torch_param.data.copy_(zero_param.data)
-    torch_optimizer = torch.optim.Adam(torch_model.parameters())
-    torch_model = torch_model.cuda()
-    grad_handler = MoeGradientHandler(torch_model)
-
-    for _ in range(2):
-        data = torch.randn(16, 4).cuda() / (local_rank + 1)
-        label = torch.randint(0, 4, (16,)).cuda()
-        run_fwd_bwd(torch_model, data, label, criterion, None)
-        run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
-        grad_handler.handle_gradient()
-
-        torch_optimizer.step()
-        zero_optimizer.step()
-
-        for (torch_name, torch_param), (zero_name, zero_param) in zip(
-            torch_model.named_parameters(), zero_model.named_parameters()
-        ):
-            assert torch.allclose(
-                torch_param.data, zero_param.data
-            ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
-
-        torch_optimizer.zero_grad()
-        zero_optimizer.zero_grad()
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
+
+
+def run_zero_test(local_rank, stage=1):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel="EP")
+    moe_model = MoeModel().bfloat16()
+    moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
+    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    moe_booster = Booster(plugin=moe_plugin)
+    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel=None)
+    zero_model = MoeModel().bfloat16()
+    delete_moe_info(zero_model)
+    sync_local_from_ep(zero_model, moe_model)
+    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
+    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    zero_booster = Booster(plugin=zero_plugin)
+    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
+
+    for (moe_name, moe_param), (zero_name, zero_param) in zip(
+        moe_model.named_parameters(), zero_model.named_parameters()
+    ):
+        if ".experts." in moe_name:
+            continue
+        assert moe_name == zero_name
+        assert torch.allclose(
+            moe_param.data, zero_param.data
+        ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
+
+    for _ in range(1):
+        data = torch.randn(2, 4).bfloat16().cuda()
+        label = torch.randint(0, 4, (2,)).cuda()
+
+        moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
+        zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+        assert torch.allclose(zero_out, moe_out)
+        moe_optimizer.step()
+        zero_optimizer.step()
+
+        for (moe_name, moe_param), (zero_name, zero_param) in zip(
+            moe_model.named_parameters(), zero_model.named_parameters()
+        ):
+            assert moe_name == zero_name
+            if is_moe_tensor(moe_param):
+                param_size = moe_param.shape[0]
+                zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
+            loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
+
+        moe_optimizer.zero_grad()
+        zero_optimizer.zero_grad()


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    MOE_MANAGER.setup(parallel="EP")
-    run_zero_optim_test(rank, world_size, stage=1)
-    run_zero_optim_test(rank, world_size, stage=2)
+    seed_all(42 + rank)
+    run_zero_test(rank, stage=stage)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("stage", [1, 2])
 @rerun_if_address_is_in_use()
-def test_moe_zero_optim(world_size):
-    spawn(run_dist, world_size)
+def test_moe_zero_optim(world_size, stage):
+    spawn(run_dist, world_size, stage=stage)


 if __name__ == "__main__":
-    test_moe_zero_optim(world_size=2)
+    test_moe_zero_optim(world_size=2, stage=1)
```