Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
efef43b5
Unverified
Commit
efef43b5
authored
Feb 08, 2024
by
Frank Lee
Committed by
GitHub
Feb 08, 2024
Browse files
Merge pull request #5372 from hpcaitech/exp/mixtral
parents
4c03347f
06db94fb
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
444 additions
and
266 deletions
+444
-266
colossalai/moe/_operation.py
colossalai/moe/_operation.py
+66
-1
colossalai/moe/checkpoint.py
colossalai/moe/checkpoint.py
+39
-28
colossalai/moe/experts.py
colossalai/moe/experts.py
+5
-1
colossalai/moe/layers.py
colossalai/moe/layers.py
+40
-32
colossalai/moe/routers.py
colossalai/moe/routers.py
+36
-11
colossalai/moe/utils.py
colossalai/moe/utils.py
+2
-0
colossalai/tensor/moe_tensor/moe_info.py
colossalai/tensor/moe_tensor/moe_info.py
+2
-0
colossalai/zero/low_level/low_level_optim.py
colossalai/zero/low_level/low_level_optim.py
+47
-16
tests/test_moe/moe_utils.py
tests/test_moe/moe_utils.py
+80
-3
tests/test_moe/test_moe_checkpoint.py
tests/test_moe/test_moe_checkpoint.py
+5
-22
tests/test_moe/test_moe_router.py
tests/test_moe/test_moe_router.py
+19
-13
tests/test_moe/test_moe_zero_fwd_bwd.py
tests/test_moe/test_moe_zero_fwd_bwd.py
+48
-75
tests/test_moe/test_moe_zero_optim.py
tests/test_moe/test_moe_zero_optim.py
+55
-64
No files found.
colossalai/moe/_operation.py
View file @
efef43b5
from
typing
import
Any
,
Optional
,
Tuple
from
typing
import
Any
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
...
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
if
ctx
.
ep_size
!=
1
:
if
ctx
.
ep_size
!=
1
:
grad
=
grad
/
ctx
.
ep_size
grad
=
grad
/
ctx
.
ep_size
return
grad
,
None
return
grad
,
None
def
_all_to_all
(
inputs
:
torch
.
Tensor
,
input_split_sizes
:
Optional
[
List
[
int
]]
=
None
,
output_split_sizes
:
Optional
[
List
[
int
]]
=
None
,
group
=
None
,
async_op
:
bool
=
False
,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
outputs_shape
=
list
(
inputs
.
shape
)
if
output_split_sizes
is
not
None
:
outputs_shape
[
0
]
=
sum
(
output_split_sizes
)
outputs
=
torch
.
empty
(
outputs_shape
,
dtype
=
inputs
.
dtype
,
device
=
inputs
.
device
)
inputs
=
inputs
.
contiguous
()
outputs
=
outputs
.
contiguous
()
handle
=
dist
.
all_to_all_single
(
outputs
,
inputs
,
output_split_sizes
,
input_split_sizes
,
group
=
group
,
async_op
=
async_op
)
return
outputs
,
handle
class
AllToAllUneven
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
inputs
,
input_split_sizes
=
None
,
output_split_sizes
=
None
,
group
=
None
,
overlap
:
bool
=
False
,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
ctx
.
input_split_sizes
=
input_split_sizes
ctx
.
output_split_sizes
=
output_split_sizes
ctx
.
group
=
group
return
_all_to_all
(
inputs
,
input_split_sizes
,
output_split_sizes
,
group
,
overlap
)
@
staticmethod
def
backward
(
ctx
:
Any
,
*
grad_outputs
):
return
(
_all_to_all
(
grad_outputs
[
0
],
ctx
.
output_split_sizes
,
ctx
.
input_split_sizes
,
ctx
.
group
,
False
)[
0
],
None
,
None
,
None
,
None
,
)
def
all_to_all_uneven
(
inputs
:
torch
.
Tensor
,
input_split_sizes
:
Optional
[
List
[
int
]]
=
None
,
output_split_sizes
:
Optional
[
List
[
int
]]
=
None
,
group
=
None
,
overlap
:
bool
=
False
,
):
return
AllToAllUneven
.
apply
(
inputs
,
input_split_sizes
,
output_split_sizes
,
group
,
overlap
)
colossalai/moe/checkpoint.py
View file @
efef43b5
...
@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
"""
torch
.
cuda
.
empty_cache
()
if
os
.
path
.
isfile
(
checkpoint
):
if
os
.
path
.
isfile
(
checkpoint
):
logging
.
error
(
f
"Provided path (
{
checkpoint
}
) should be a directory, not a file"
)
logging
.
error
(
f
"Provided path (
{
checkpoint
}
) should be a directory, not a file"
)
return
return
...
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f
"index located at
{
save_index_file
}
."
f
"index located at
{
save_index_file
}
."
)
)
dist
.
barrier
()
dist
.
barrier
()
torch
.
cuda
.
empty_cache
()
# ========================================================
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# Abstract methods for optimizer loading/saving implementation
...
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
assert
isinstance
(
optimizer
,
OptimizerWrapper
),
"Please boost the optimizer before loading!"
assert
isinstance
(
optimizer
,
OptimizerWrapper
),
"Please boost the optimizer before loading!"
def
_get_param_id_from_optimizer_param
(
def
_get_param_id_from_optimizer_param
(
param
:
torch
.
Tensor
,
master_to_working_map
:
Optional
[
Dict
[
int
,
torch
.
Tensor
]]
=
None
param
:
torch
.
Tensor
,
master_to_working_map
:
Optional
[
Dict
[
int
,
torch
.
Tensor
]]
=
None
,
optimizer
=
None
):
):
if
master_to_working_map
is
not
None
and
id
(
param
)
in
master_to_working_map
:
if
master_to_working_map
is
not
None
and
id
(
param
)
in
master_to_working_map
:
working_param
=
master_to_working_map
[
id
(
param
)]
working_param
=
master_to_working_map
[
id
(
param
)]
elif
hasattr
(
optimizer
,
"moe_master_to_working_map"
)
and
id
(
param
)
in
optimizer
.
moe_master_to_working_map
:
working_param
=
optimizer
.
moe_master_to_working_map
[
id
(
param
)]
else
:
else
:
working_param
=
param
working_param
=
param
return
optimizer
.
param_info
[
"param2id"
][
id
(
working_param
)]
return
optimizer
.
param_info
[
"param2id"
][
id
(
working_param
)]
...
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
master_to_working_map
=
optimizer
.
get_master_to_working_map
()
master_to_working_map
=
optimizer
.
get_master_to_working_map
()
for
pg
in
optimizer
.
optim
.
param_groups
:
for
pg
in
optimizer
.
optim
.
param_groups
:
for
param
in
pg
[
"params"
]:
for
param
in
pg
[
"params"
]:
param_id
=
_get_param_id_from_optimizer_param
(
param
,
master_to_working_map
)
param_id
=
_get_param_id_from_optimizer_param
(
param
,
master_to_working_map
,
optimizer
)
id_map
[
param_id
]
=
param
id_map
[
param_id
]
=
param
# Read checkpoint index file.
# Read checkpoint index file.
...
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
new_pg
=
copy
.
deepcopy
(
saved_pg
)
new_pg
=
copy
.
deepcopy
(
saved_pg
)
new_pg
[
"params"
]
=
old_pg
[
"params"
]
# The parameters in the same group shouln't change.
new_pg
[
"params"
]
=
old_pg
[
"params"
]
# The parameters in the same group shouln't change.
updated_groups
.
append
(
new_pg
)
updated_groups
.
append
(
new_pg
)
# ep
ext
ra group
# ep
pa
ra
m
group
if
MOE_MANAGER
.
parallel
==
"EP"
:
if
len
(
optimizer
.
optim
.
param_groups
)
>
len
(
saved_groups
)
:
new_pg
=
copy
.
deepcopy
(
saved_pg
)
new_pg
=
copy
.
deepcopy
(
saved_pg
)
new_pg
[
"params"
]
=
optimizer
.
optim
.
param_groups
[
-
1
][
new_pg
[
"params"
]
=
optimizer
.
optim
.
param_groups
[
-
1
][
"params"
]
"params"
]
# Only keep the parameters kept by current pipeline stage.
for
param
in
new_pg
[
"params"
]:
param
.
data
=
param
.
data
.
to
(
torch
.
float32
)
updated_groups
.
append
(
new_pg
)
updated_groups
.
append
(
new_pg
)
optimizer
.
optim
.
__dict__
.
update
({
"param_groups"
:
updated_groups
})
optimizer
.
optim
.
__dict__
.
update
({
"param_groups"
:
updated_groups
})
...
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for
param
in
pg
[
"params"
]:
for
param
in
pg
[
"params"
]:
if
param
is
None
:
if
param
is
None
:
continue
continue
param_id
=
_get_param_id_from_optimizer_param
(
param
,
master_to_working_map
)
param_id
=
_get_param_id_from_optimizer_param
(
param
,
master_to_working_map
,
optimizer
)
if
param_id
not
in
weight_map
:
if
param_id
not
in
weight_map
:
continue
continue
filename
=
weight_map
[
param_id
]
filename
=
weight_map
[
param_id
]
...
@@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
file_path
=
os
.
path
.
join
(
ckpt_root_path
,
filename
)
file_path
=
os
.
path
.
join
(
ckpt_root_path
,
filename
)
state_dict
=
load_shard_state_dict
(
Path
(
file_path
),
use_safetensors
=
False
)
state_dict
=
load_shard_state_dict
(
Path
(
file_path
),
use_safetensors
=
False
)
# Then shard the loaded optimizer states if using tp/zero.
for
pid
,
state
in
list
(
state_dict
.
items
()):
if
pid
in
id_map
:
param
=
id_map
[
pid
]
if
master_to_working_map
is
not
None
and
id
(
param
)
in
master_to_working_map
:
working_param
=
master_to_working_map
[
id
(
param
)]
elif
(
hasattr
(
optimizer
,
"moe_master_to_working_map"
)
and
id
(
param
)
in
optimizer
.
moe_master_to_working_map
):
working_param
=
optimizer
.
moe_master_to_working_map
[
id
(
param
)]
else
:
working_param
=
param
original_shape
=
optimizer
.
param_info
[
"param2shape"
][
id
(
working_param
)]
sharded_state
=
self
.
pre_load_optim
(
state
,
working_param
,
current_shape
=
working_param
.
shape
,
original_shape
=
original_shape
,
device
=
"cpu"
,
inplace
=
True
,
)
state_dict
[
pid
]
=
sharded_state
load_states_into_optimizer
(
optimizer
.
optim
,
state_dict
,
id_map
,
strict
=
True
)
load_states_into_optimizer
(
optimizer
.
optim
,
state_dict
,
id_map
,
strict
=
True
)
loaded_file
.
add
(
filename
)
loaded_file
.
add
(
filename
)
# Then shard the loaded optimizer states if using tp/zero.
for
param
,
state
in
optimizer
.
optim
.
state
.
items
():
device
=
param
.
device
if
master_to_working_map
is
not
None
and
id
(
param
)
in
master_to_working_map
:
working_param
=
master_to_working_map
[
id
(
param
)]
else
:
working_param
=
param
original_shape
=
optimizer
.
param_info
[
"param2shape"
][
id
(
working_param
)]
sharded_state
=
self
.
pre_load_optim
(
state
,
param
,
current_shape
=
working_param
.
shape
,
original_shape
=
original_shape
,
device
=
device
,
inplace
=
True
,
)
optimizer
.
optim
.
state
[
param
]
=
sharded_state
sharded_optimizer_loading_epilogue
(
optimizer
.
optim
)
sharded_optimizer_loading_epilogue
(
optimizer
.
optim
)
if
self
.
verbose
and
self
.
coordinator
.
is_master
():
if
self
.
verbose
and
self
.
coordinator
.
is_master
():
logging
.
info
(
f
"The optimizer has been successfully loaded from sharded checkpoint:
{
ckpt_root_path
}
."
)
logging
.
info
(
f
"The optimizer has been successfully loaded from sharded checkpoint:
{
ckpt_root_path
}
."
)
...
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
if
master_to_working_map
is
not
None
and
id
(
param
)
in
master_to_working_map
:
if
master_to_working_map
is
not
None
and
id
(
param
)
in
master_to_working_map
:
working_param
=
master_to_working_map
[
id
(
param
)]
working_param
=
master_to_working_map
[
id
(
param
)]
elif
hasattr
(
optimizer
,
"moe_master_to_working_map"
)
and
id
(
param
)
in
optimizer
.
moe_master_to_working_map
:
working_param
=
optimizer
.
moe_master_to_working_map
[
id
(
param
)]
else
:
else
:
working_param
=
param
working_param
=
param
...
@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
prefix (str): Perfix of file to save
prefix (str): Perfix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
size_per_shard (int): Max file size of each file shard that store state tensors
"""
"""
torch
.
cuda
.
empty_cache
()
assert
isinstance
(
optimizer
,
OptimizerWrapper
),
"Please boost the optimizer before saving!"
assert
isinstance
(
optimizer
,
OptimizerWrapper
),
"Please boost the optimizer before saving!"
if
os
.
path
.
isfile
(
checkpoint
):
if
os
.
path
.
isfile
(
checkpoint
):
logging
.
error
(
f
"Provided path (
{
checkpoint
}
) should be a directory, not a file"
)
logging
.
error
(
f
"Provided path (
{
checkpoint
}
) should be a directory, not a file"
)
...
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
...
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f
"You can find where each parameters has been saved in the "
f
"You can find where each parameters has been saved in the "
f
"index located at
{
final_index_file_path
}
."
f
"index located at
{
final_index_file_path
}
."
)
)
torch
.
cuda
.
empty_cache
()
def
save_unsharded_optimizer
(
self
,
optimizer
:
OptimizerWrapper
,
checkpoint
:
str
,
gather_dtensor
:
bool
):
def
save_unsharded_optimizer
(
self
,
optimizer
:
OptimizerWrapper
,
checkpoint
:
str
,
gather_dtensor
:
bool
):
"""
"""
...
...
colossalai/moe/experts.py
View file @
efef43b5
...
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
...
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self
.
ep_size
=
1
self
.
ep_size
=
1
if
gated
:
if
gated
:
self
.
wi_gate
=
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
*
2
))
self
.
wi_gate
=
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
*
2
if
activation
==
"swiglu"
else
intermediate_size
)
)
self
.
wi_up
=
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
))
self
.
wi_up
=
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
))
else
:
else
:
self
.
wi
=
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
))
self
.
wi
=
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
))
...
...
colossalai/moe/layers.py
View file @
efef43b5
...
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
...
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
router_top_k
:
int
=
1
,
router_top_k
:
int
=
1
,
router_loss
:
bool
=
True
,
router_norm
:
bool
=
False
,
router_capacity_factor_train
:
float
=
1.25
,
router_capacity_factor_train
:
float
=
1.25
,
router_capacity_factor_eval
:
float
=
2.0
,
router_capacity_factor_eval
:
float
=
2.0
,
router_min_capacity
:
int
=
4
,
router_min_capacity
:
int
=
4
,
...
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
...
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel
:
bool
=
False
,
enable_kernel
:
bool
=
False
,
enable_comm_overlap
:
bool
=
False
,
enable_comm_overlap
:
bool
=
False
,
enable_hierarchical_comm
:
bool
=
False
,
enable_hierarchical_comm
:
bool
=
False
,
return_gate_logits
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
intermediate_size
=
intermediate_size
self
.
num_experts
=
num_experts
self
.
num_experts
=
num_experts
self
.
gated
=
mlp_gated
self
.
gated
=
mlp_gated
self
.
return_gate_logits
=
return_gate_logits
self
.
enable_kernel
=
enable_kernel
self
.
enable_kernel
=
enable_kernel
self
.
enable_comm_overlap
=
enable_comm_overlap
self
.
enable_comm_overlap
=
enable_comm_overlap
self
.
expert_parallel
=
MOE_MANAGER
.
get_parallel
()
self
.
expert_parallel
=
MOE_MANAGER
.
get_parallel
()
self
.
router_loss
=
router_loss
self
.
router_norm
=
router_norm
# moe router
# moe router
noisy_func
=
get_noise_generator
(
router_noisy_policy
,
num_experts
)
noisy_func
=
get_noise_generator
(
router_noisy_policy
,
num_experts
)
...
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
...
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens
=
inputs
.
reshape
(
-
1
,
self
.
hidden_size
)
tokens
=
inputs
.
reshape
(
-
1
,
self
.
hidden_size
)
# the data type of the inputs in the gating should be fp32
# the data type of the inputs in the gating should be fp32
fp32_input
=
tokens
.
to
(
torch
.
float
)
gate_logits
=
F
.
linear
(
tokens
,
self
.
gate_weight
)
fp32_weight
=
self
.
gate_weight
.
to
(
torch
.
float
)
gate_output
=
gate_logits
.
to
(
torch
.
float
)
gate_output
=
F
.
linear
(
fp32_input
,
fp32_weight
)
# update expert load
# update expert load
if
self
.
enable_load_balance
==
True
:
if
self
.
enable_load_balance
==
True
:
...
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
...
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
# the result from the router
# the result from the router
used_capacity
,
*
route_result_list
=
self
.
router
(
used_capacity
,
*
route_result_list
=
self
.
router
(
inputs
=
gate_output
,
use_kernel
=
self
.
enable_kernel
,
ep_group
=
self
.
ep_group
)
inputs
=
gate_output
,
use_kernel
=
self
.
enable_kernel
,
ep_group
=
self
.
ep_group
,
use_loss
=
self
.
router_loss
,
use_norm
=
self
.
router_norm
,
)
# dispatch_data: (num_experts, capacity, hidden_size)
# dispatch_data: (num_experts, capacity, hidden_size)
if
self
.
enable_kernel
:
if
self
.
enable_kernel
:
...
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
...
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
# expert_output: (num_groups, num_experts, capacity, hidden_size)
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if
self
.
expert_parallel
==
"EP"
:
if
self
.
expert_parallel
==
"EP"
:
expert_output
=
self
.
_ep_process
(
expert_output
=
self
.
_ep_process
(
dispatch_data
,
used_capacity
,
overlap
=
self
.
enable_comm_overlap
)
dispatch_data
,
used_capacity
,
overlap
=
self
.
enable_comm_overlap
)
elif
self
.
expert_parallel
==
"TP"
:
elif
self
.
expert_parallel
==
"TP"
:
expert_output
=
self
.
_tp_process
(
expert_output
=
self
.
_tp_process
(
dispatch_data
,
used_capacity
,
overlap
=
self
.
enable_comm_overlap
)
dispatch_data
,
used_capacity
,
overlap
=
self
.
enable_comm_overlap
)
elif
self
.
expert_parallel
is
None
:
elif
self
.
expert_parallel
is
None
:
expert_output
=
self
.
_local_process
(
dispatch_data
)
expert_output
=
self
.
_local_process
(
dispatch_data
)
else
:
else
:
raise
NotImplementedError
(
"This kind of communication has not been implemented yet.
\n
"
raise
NotImplementedError
(
"Please use Experts build function."
)
"This kind of communication has not been implemented yet.
\n
"
"Please use Experts build function."
)
if
self
.
enable_kernel
:
if
self
.
enable_kernel
:
expert_output
=
expert_output
.
reshape
(
-
1
,
self
.
hidden_size
)
expert_output
=
expert_output
.
reshape
(
-
1
,
self
.
hidden_size
)
...
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
...
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
ans
=
torch
.
matmul
(
combine_weights
,
expert_output
)
ans
=
torch
.
matmul
(
combine_weights
,
expert_output
)
ans
=
ans
.
reshape
(
inputs
.
shape
)
ans
=
ans
.
reshape
(
inputs
.
shape
)
return
ans
if
self
.
return_gate_logits
:
return
ans
,
gate_logits
else
:
return
ans
def
_local_process
(
self
,
expert_in
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_local_process
(
self
,
expert_in
:
torch
.
Tensor
)
->
torch
.
Tensor
:
expert_in
=
expert_in
.
unsqueeze
(
0
)
expert_in
=
expert_in
.
unsqueeze
(
0
)
...
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
...
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return
expert_out
return
expert_out
def
_ep_process
(
def
_ep_process
(
self
,
self
,
dispatch_data
:
torch
.
Tensor
,
used_capacity
:
torch
.
Tensor
,
overlap
:
bool
=
False
dispatch_data
:
torch
.
Tensor
,
used_capacity
:
torch
.
Tensor
,
overlap
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Expert Parallel
Expert Parallel
...
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
...
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
"""
"""
if
not
overlap
or
dist
.
get_world_size
(
self
.
ep_group
)
==
1
:
if
not
overlap
or
dist
.
get_world_size
(
self
.
ep_group
)
==
1
:
if
self
.
ep_hierarchical_group
is
not
None
:
if
self
.
ep_hierarchical_group
is
not
None
:
expert_input
=
HierarchicalAllToAll
.
apply
(
dispatch_data
,
self
.
ep_hierarchical_group
,
self
.
ep_intra_src_rank
)
expert_input
=
HierarchicalAllToAll
.
apply
(
dispatch_data
,
self
.
ep_hierarchical_group
,
self
.
ep_intra_src_rank
)
expert_input
=
expert_input
.
reshape
(
self
.
ep_size
,
self
.
num_local_experts
,
-
1
,
self
.
hidden_size
)
expert_input
=
expert_input
.
reshape
(
self
.
ep_size
,
self
.
num_local_experts
,
-
1
,
self
.
hidden_size
)
expert_output
=
self
.
experts
(
expert_input
)
expert_output
=
self
.
experts
(
expert_input
)
expert_output
=
HierarchicalAllToAll
.
apply
(
expert_output
,
self
.
ep_hierarchical_group
,
self
.
ep_intra_src_rank
)
expert_output
=
HierarchicalAllToAll
.
apply
(
expert_output
,
self
.
ep_hierarchical_group
,
self
.
ep_intra_src_rank
)
return
expert_output
return
expert_output
else
:
else
:
expert_input
=
AllToAll
.
apply
(
dispatch_data
,
self
.
ep_group
,
False
)[
0
]
expert_input
=
AllToAll
.
apply
(
dispatch_data
,
self
.
ep_group
,
False
)[
0
]
...
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
...
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK
=
4
NUM_CHUNK
=
4
NUM_STAGES
=
4
NUM_STAGES
=
4
assert
(
dispatch_data
.
shape
[
1
]
%
NUM_CHUNK
==
0
)
,
"arbitrary chunk num is not supported yet"
assert
dispatch_data
.
shape
[
1
]
%
NUM_CHUNK
==
0
,
"arbitrary chunk num is not supported yet"
chunk_size
=
dispatch_data
.
shape
[
1
]
//
NUM_CHUNK
chunk_size
=
dispatch_data
.
shape
[
1
]
//
NUM_CHUNK
input_shape
=
(
self
.
ep_size
,
self
.
num_local_experts
,
-
1
,
self
.
hidden_size
)
input_shape
=
(
self
.
ep_size
,
self
.
num_local_experts
,
-
1
,
self
.
hidden_size
)
dispatch_data
=
dispatch_data
.
reshape
(
*
input_shape
)
dispatch_data
=
dispatch_data
.
reshape
(
*
input_shape
)
...
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
...
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for
i
in
range
(
NUM_CHUNK
+
NUM_STAGES
-
1
):
for
i
in
range
(
NUM_CHUNK
+
NUM_STAGES
-
1
):
if
expert_out
is
not
None
:
if
expert_out
is
not
None
:
expert_out
.
handle
.
wait
()
expert_out
.
handle
.
wait
()
output
[:,
:,
offset
:
offset
+
chunk_size
,
:]
=
expert_out
.
data
output
[:,
:,
offset
:
offset
+
chunk_size
,
:]
=
expert_out
.
data
offset
+=
chunk_size
offset
+=
chunk_size
expert_out
=
None
expert_out
=
None
# all2all last output
# all2all last output
if
_expert_out
is
not
None
:
if
_expert_out
is
not
None
:
expert_out
=
Capsule
(
*
AllToAll
.
apply
(
_expert_out
.
data
,
self
.
ep_group
,
True
),)
expert_out
=
Capsule
(
*
AllToAll
.
apply
(
_expert_out
.
data
,
self
.
ep_group
,
True
),
)
_expert_out
=
None
_expert_out
=
None
# all2all next input
# all2all next input
...
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
...
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return
output
return
output
def
_tp_process
(
def
_tp_process
(
self
,
self
,
dispatch_data
:
torch
.
Tensor
,
used_capacity
:
torch
.
Tensor
,
overlap
:
bool
=
False
dispatch_data
:
torch
.
Tensor
,
used_capacity
:
torch
.
Tensor
,
overlap
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
without overlap:
without overlap:
...
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
...
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK
=
4
NUM_CHUNK
=
4
NUM_STAGES
=
4
NUM_STAGES
=
4
assert
dispatch_data
.
shape
[
0
]
%
NUM_CHUNK
==
0
,
\
assert
(
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
dispatch_data
.
shape
[
0
]
%
NUM_CHUNK
==
0
),
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size
=
dispatch_data
.
shape
[
0
]
//
NUM_CHUNK
chunk_size
=
dispatch_data
.
shape
[
0
]
//
NUM_CHUNK
chunk_data
=
torch
.
split
(
dispatch_data
,
chunk_size
,
dim
=
0
)
chunk_data
=
torch
.
split
(
dispatch_data
,
chunk_size
,
dim
=
0
)
output
=
torch
.
empty_like
(
dispatch_data
)
output
=
torch
.
empty_like
(
dispatch_data
)
...
...
colossalai/moe/routers.py
View file @
efef43b5
...
@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
...
@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
self
.
_z_loss
=
None
self
.
_z_loss
=
None
self
.
use_kernel
=
use_kernel
self
.
use_kernel
=
use_kernel
def
get_capacity
(
self
,
logits_shape
):
def
get_capacity
(
self
,
num_tokens
,
num_experts
,
ep_group
=
None
):
if
ep_group
is
not
None
:
num_tokens_tensor
=
torch
.
tensor
(
num_tokens
,
device
=
get_accelerator
().
get_current_device
())
dist
.
all_reduce
(
num_tokens_tensor
,
group
=
ep_group
)
num_tokens
=
num_tokens_tensor
.
item
()
//
dist
.
get_world_size
(
ep_group
)
capacity_factor
=
self
.
capacity_factor_train
if
self
.
training
else
self
.
capacity_factor_eval
capacity_factor
=
self
.
capacity_factor_train
if
self
.
training
else
self
.
capacity_factor_eval
capacity
=
math
.
floor
(
self
.
k_value
*
capacity_factor
*
logits_shape
[
-
2
]
/
logits_shape
[
-
1
]
)
capacity
=
math
.
floor
(
self
.
k_value
*
capacity_factor
*
num_tokens
/
num_experts
)
capacity
+=
capacity
%
2
capacity
+=
capacity
%
2
capacity
=
max
(
capacity
,
self
.
min_capacity
)
capacity
=
max
(
capacity
,
self
.
min_capacity
)
assert
capacity
>
0
assert
capacity
>
0
...
@@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
...
@@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
high
=
torch
.
tensor
(
1.0
,
device
=
get_accelerator
().
get_current_device
()),
high
=
torch
.
tensor
(
1.0
,
device
=
get_accelerator
().
get_current_device
()),
).
rsample
).
rsample
def
forward
(
self
,
inputs
:
torch
.
Tensor
,
use_kernel
:
bool
=
False
,
ep_group
:
Optional
[
ProcessGroup
]
=
None
)
->
Tuple
:
def
forward
(
self
,
inputs
:
torch
.
Tensor
,
use_kernel
:
bool
=
False
,
ep_group
:
Optional
[
ProcessGroup
]
=
None
,
use_loss
:
bool
=
False
,
use_norm
:
bool
=
False
,
)
->
Tuple
:
"""
"""
Args:
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
...
@@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
...
@@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
assert
inputs
.
dtype
==
torch
.
float
assert
inputs
.
dtype
==
torch
.
float
probs
=
F
.
softmax
(
inputs
,
dim
=-
1
)
probs
=
F
.
softmax
(
inputs
,
dim
=-
1
)
num_experts
=
probs
.
size
(
-
1
)
num_experts
=
probs
.
size
(
-
1
)
capacity
=
self
.
get_capacity
(
inputs
.
shape
)
num_tokens
=
inputs
.
size
(
0
)
capacity
=
self
.
get_capacity
(
num_tokens
,
num_experts
,
ep_group
)
top1_idx
=
torch
.
argmax
(
inputs
,
dim
=-
1
)
top1_idx
=
torch
.
argmax
(
inputs
,
dim
=-
1
)
mask
=
F
.
one_hot
(
top1_idx
,
num_classes
=
num_experts
).
to
(
torch
.
int32
)
mask
=
F
.
one_hot
(
top1_idx
,
num_classes
=
num_experts
).
to
(
torch
.
int32
)
...
@@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
...
@@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
weight
=
mask
*
probs
.
type_as
(
inputs
)
weight
=
mask
*
probs
.
type_as
(
inputs
)
combine_weights
=
weight
.
unsqueeze
(
2
)
*
ranks
.
unsqueeze
(
1
)
combine_weights
=
weight
.
unsqueeze
(
2
)
*
ranks
.
unsqueeze
(
1
)
sec_mask
=
combine_weights
.
bool
()
sec_mask
=
combine_weights
.
bool
()
return
used_capacity
,
combine_weights
,
sec_mask
return
used_capacity
,
combine_weights
,
sec_mask
,
probs
class
Top2Router
(
MoeRouter
):
class
Top2Router
(
MoeRouter
):
...
@@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
...
@@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
drop_tks
=
drop_tks
,
drop_tks
=
drop_tks
,
)
)
def
forward
(
self
,
inputs
:
torch
.
Tensor
,
use_kernel
:
bool
=
False
,
ep_group
:
Optional
[
ProcessGroup
]
=
None
)
->
Tuple
:
def
forward
(
self
,
inputs
:
torch
.
Tensor
,
use_kernel
:
bool
=
False
,
ep_group
:
Optional
[
ProcessGroup
]
=
None
,
use_norm
:
bool
=
False
,
use_loss
:
bool
=
True
,
)
->
Tuple
:
"""
"""
Args:
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
...
@@ -257,8 +276,13 @@ class Top2Router(MoeRouter):
...
@@ -257,8 +276,13 @@ class Top2Router(MoeRouter):
assert
inputs
.
dtype
==
torch
.
float
assert
inputs
.
dtype
==
torch
.
float
probs
=
F
.
softmax
(
inputs
,
dim
=-
1
)
probs
=
F
.
softmax
(
inputs
,
dim
=-
1
)
if
use_norm
:
routing_weights
,
_
=
torch
.
topk
(
probs
,
2
,
dim
=-
1
)
probs
=
probs
/
routing_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
num_experts
=
probs
.
size
(
-
1
)
num_experts
=
probs
.
size
(
-
1
)
capacity
=
self
.
get_capacity
(
inputs
.
shape
)
num_tokens
=
inputs
.
size
(
0
)
capacity
=
self
.
get_capacity
(
num_tokens
,
num_experts
,
ep_group
)
top1_idx
=
torch
.
argmax
(
probs
,
dim
=-
1
)
top1_idx
=
torch
.
argmax
(
probs
,
dim
=-
1
)
mask1
=
F
.
one_hot
(
top1_idx
,
num_classes
=
num_experts
).
to
(
torch
.
int32
)
mask1
=
F
.
one_hot
(
top1_idx
,
num_classes
=
num_experts
).
to
(
torch
.
int32
)
...
@@ -270,10 +294,11 @@ class Top2Router(MoeRouter):
...
@@ -270,10 +294,11 @@ class Top2Router(MoeRouter):
cmask
=
cmask
.
float
()
/
2.0
# div 2 to normalize it to 1
cmask
=
cmask
.
float
()
/
2.0
# div 2 to normalize it to 1
# calculate loss
# calculate loss
expert_indices
=
torch
.
stack
([
top1_idx
,
top2_idx
],
dim
=-
1
)
if
use_loss
:
self
.
set_aux_loss
(
probs
,
expert_indices
,
num_experts
)
expert_indices
=
torch
.
stack
([
top1_idx
,
top2_idx
],
dim
=-
1
)
self
.
set_z_loss
(
inputs
)
self
.
set_aux_loss
(
probs
,
expert_indices
,
num_experts
)
self
.
pop_router_loss
()
self
.
set_z_loss
(
inputs
)
self
.
pop_router_loss
()
if
not
self
.
training
and
not
self
.
drop_tks
and
ep_group
is
not
None
:
if
not
self
.
training
and
not
self
.
drop_tks
and
ep_group
is
not
None
:
max_num
=
torch
.
max
(
torch
.
sum
(
cmask
,
dim
=
0
))
max_num
=
torch
.
max
(
torch
.
sum
(
cmask
,
dim
=
0
))
...
...
colossalai/moe/utils.py
View file @
efef43b5
...
@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
...
@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
return
torch
.
nn
.
GELU
()
return
torch
.
nn
.
GELU
()
elif
act
==
"swiglu"
:
elif
act
==
"swiglu"
:
return
SwiGLU
return
SwiGLU
elif
act
==
"silu"
:
return
torch
.
nn
.
SiLU
()
else
:
else
:
raise
NotImplementedError
(
"Unsupported activation function"
)
raise
NotImplementedError
(
"Unsupported activation function"
)
...
...
colossalai/tensor/moe_tensor/moe_info.py
View file @
efef43b5
...
@@ -26,3 +26,5 @@ class MoeParallelInfo:
...
@@ -26,3 +26,5 @@ class MoeParallelInfo:
self
.
ep_group_ranks
=
self
.
pg
.
get_ranks_in_group
(
self
.
ep_group
)
self
.
ep_group_ranks
=
self
.
pg
.
get_ranks_in_group
(
self
.
ep_group
)
self
.
dp_group
=
self
.
pg
.
get_group_along_axis
(
self
.
dp_axis
)
self
.
dp_group
=
self
.
pg
.
get_group_along_axis
(
self
.
dp_axis
)
self
.
dp_group_ranks
=
self
.
pg
.
get_ranks_in_group
(
self
.
dp_group
)
self
.
dp_group_ranks
=
self
.
pg
.
get_ranks_in_group
(
self
.
dp_group
)
self
.
ep_rank
=
self
.
pg
.
coordinate
(
self
.
ep_axis
)
self
.
dp_rank
=
self
.
pg
.
coordinate
(
self
.
dp_axis
)
colossalai/zero/low_level/low_level_optim.py
View file @
efef43b5
...
@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
...
@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# because they have different parallel strategy
# because they have different parallel strategy
# so we need to store them separately in param_groups
# so we need to store them separately in param_groups
# instead of working_groups
# instead of working_groups
moe_params
=
list
()
self
.
working_
moe_params
=
list
()
# iterate over the param group in the optimizer
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# partition these param groups for data parallel training
...
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
...
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if
self
.
moe_extra_dp_pg
is
None
:
if
self
.
moe_extra_dp_pg
is
None
:
# skip moe param
# skip moe param
if
is_moe_tensor
(
param
):
if
is_moe_tensor
(
param
):
moe_params
.
append
(
param
)
self
.
working_
moe_params
.
append
(
param
)
continue
continue
group_params
.
append
(
param
)
group_params
.
append
(
param
)
...
@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
...
@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
# managed by this data parallel rank
param_group
[
"params"
]
=
master_param_current_rank
param_group
[
"params"
]
=
master_param_current_rank
# if there are moe params, store in additional group in optim
# if there are moe params, store in addtional group in optim
if
len
(
moe_params
)
>
0
:
if
len
(
self
.
working_moe_params
)
>
0
:
self
.
_sync_master_param
=
False
param_group
=
dict
()
param_group
=
dict
()
# create fp32 master param
for
key
,
value
in
self
.
optim
.
param_groups
[
0
].
items
():
for
key
,
value
in
self
.
optim
.
param_groups
[
0
].
items
():
if
key
!=
"params"
:
if
key
!=
"params"
:
param_group
[
key
]
=
value
param_group
[
key
]
=
value
param_group
[
"params"
]
=
moe_params
self
.
master_moe_params
=
[]
for
param
in
self
.
working_moe_params
:
self
.
master_moe_params
.
append
(
param
.
clone
().
to
(
torch
.
float32
).
detach
())
# create mapping from master to working for optimizer io
self
.
moe_master_to_working_map
=
{}
for
master_moe_param
,
working_moe_param
in
zip
(
self
.
master_moe_params
,
self
.
working_moe_params
):
self
.
moe_master_to_working_map
[
id
(
master_moe_param
)]
=
working_moe_param
# add to optim
param_group
[
"params"
]
=
self
.
master_moe_params
self
.
optim
.
param_groups
.
append
(
param_group
)
self
.
optim
.
param_groups
.
append
(
param_group
)
# initialize communication stream for
# initialize communication stream for
...
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
...
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the params in the optimizer
# update the params in the optimizer
self
.
optim
.
param_groups
[
group_id
][
"params"
]
=
real_master_params
[
group_id
]
self
.
optim
.
param_groups
[
group_id
][
"params"
]
=
real_master_params
[
group_id
]
# update param for moe ep
# move grad to master param and compute norm
if
len
(
self
.
working_moe_params
)
>
0
:
moe_grads
=
[]
for
master_moe_param
,
working_moe_param
in
zip
(
self
.
master_moe_params
,
self
.
working_moe_params
):
if
master_moe_param
.
grad
is
not
None
:
raise
RuntimeError
(
"Moe param should not have grad here"
)
grad
=
working_moe_param
.
grad
# no need to copy fp32 grad if master_weights is False
if
self
.
_master_weights
:
grad
=
grad
.
to
(
master_moe_param
.
dtype
).
to
(
master_moe_param
.
device
)
master_moe_param
.
grad
=
grad
working_moe_param
.
grad
=
None
moe_grads
.
append
(
grad
)
grad_partition_groups
.
append
(
grad
)
norm_group
=
self
.
_compute_grad_norm
(
gradients
=
moe_grads
)
norm_groups
.
append
(
norm_group
)
self
.
optim
.
param_groups
[
-
1
][
"params"
]
=
self
.
master_moe_params
del
moe_grads
# unscale and clip grads
# unscale and clip grads
global_norm
=
calculate_global_norm_from_list
(
norm_list
=
norm_groups
)
global_norm
=
calculate_global_norm_from_list
(
norm_list
=
norm_groups
)
self
.
_unscale_and_clip_grads
(
grad_partition_groups
,
global_norm
)
self
.
_unscale_and_clip_grads
(
grad_partition_groups
,
global_norm
)
# TODO: we should store master param for ep
if
len
(
self
.
param_groups
)
>
len
(
self
.
_working_param_groups
):
for
param
in
self
.
param_groups
[
-
1
][
"params"
]:
param
.
data
=
param
.
data
.
to
(
torch
.
float32
)
param
.
grad
=
param
.
grad
.
to
(
torch
.
float32
)
# update the parameters
# update the parameters
self
.
optim
.
step
()
self
.
optim
.
step
()
# release the moe gradm
# release moe grad
if
len
(
self
.
param_groups
)
>
len
(
self
.
_working_param_groups
):
if
len
(
self
.
working_moe_params
)
>
0
:
for
param
in
self
.
param_groups
[
-
1
][
"params"
]:
for
master_moe_param
,
working_moe_param
in
zip
(
self
.
master_moe_params
,
self
.
working_moe_params
):
param
.
grad
=
None
master_moe_param
.
grad
=
None
param
.
data
=
param
.
data
.
to
(
self
.
_dtype
)
working_moe_param
.
data
=
(
master_moe_param
.
data
.
to
(
working_moe_param
.
device
).
to
(
working_moe_param
.
dtype
).
detach
()
)
# release the grad
# release the grad
grad_partition_groups
=
[]
grad_partition_groups
=
[]
...
@@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
...
@@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param
.
copy_
(
working_param
.
chunk
(
self
.
extra_dp_pg_size
)[
self
.
extra_dp_pg_rank
])
master_param
.
copy_
(
working_param
.
chunk
(
self
.
extra_dp_pg_size
)[
self
.
extra_dp_pg_rank
])
else
:
else
:
master_param
.
copy_
(
working_param
.
chunk
(
self
.
_world_size
)[
self
.
_local_rank
])
master_param
.
copy_
(
working_param
.
chunk
(
self
.
_world_size
)[
self
.
_local_rank
])
if
hasattr
(
self
,
"master_moe_params"
):
for
master_moe_param
,
working_moe_param
in
zip
(
self
.
master_moe_params
,
self
.
working_moe_params
):
master_moe_param
.
copy_
(
working_moe_param
)
def
get_working_to_master_map
(
self
)
->
Dict
[
int
,
torch
.
Tensor
]:
def
get_working_to_master_map
(
self
)
->
Dict
[
int
,
torch
.
Tensor
]:
return
self
.
_param_store
.
working_to_master_param
return
self
.
_param_store
.
working_to_master_param
def
get_master_to_working_map
(
self
)
->
Dict
[
int
,
torch
.
Tensor
]:
def
get_master_to_working_map
(
self
)
->
Dict
[
int
,
torch
.
Tensor
]:
if
hasattr
(
self
,
"moe_master_to_working_map"
):
return
{
**
self
.
_param_store
.
master_to_working_param
,
**
self
.
moe_master_to_working_map
}
return
self
.
_param_store
.
master_to_working_param
return
self
.
_param_store
.
master_to_working_param
tests/test_moe/moe_utils.py
View file @
efef43b5
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.testing
import
assert_close
from
colossalai.booster.plugin.low_level_zero_plugin
import
LowLevelZeroModel
from
colossalai.legacy.engine.gradient_handler._base_gradient_handler
import
BaseGradientHandler
from
colossalai.legacy.engine.gradient_handler._base_gradient_handler
import
BaseGradientHandler
from
colossalai.legacy.engine.gradient_handler.utils
import
bucket_allreduce
from
colossalai.legacy.engine.gradient_handler.utils
import
bucket_allreduce
from
colossalai.legacy.registry
import
GRADIENT_HANDLER
from
colossalai.legacy.registry
import
GRADIENT_HANDLER
from
colossalai.moe
import
SparseMLP
from
colossalai.moe
import
SparseMLP
from
colossalai.moe.manager
import
MOE_MANAGER
from
colossalai.moe.manager
import
MOE_MANAGER
from
colossalai.moe.utils
import
get_moe_epsize_param_dict
from
colossalai.moe.utils
import
get_moe_epsize_param_dict
from
colossalai.tensor.moe_tensor.api
import
get_ep_group
,
get_ep_size
def
delete_moe_info
(
model
):
for
_
,
param
in
model
.
named_parameters
():
if
hasattr
(
param
,
"moe_info"
):
delattr
(
param
,
"moe_info"
)
class
MoeModel
(
nn
.
Module
):
class
MoeModel
(
nn
.
Module
):
...
@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
...
@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
for
i
in
range
(
world_size
-
1
):
for
i
in
range
(
world_size
-
1
):
a
=
tensor_list
[
i
]
a
=
tensor_list
[
i
]
b
=
tensor_list
[
i
+
1
]
b
=
tensor_list
[
i
+
1
]
assert
not
torch
.
allclose
(
a
,
b
),
\
assert
not
torch
.
allclose
(
a
,
b
),
(
(
f
"expected tensors on rank
{
i
}
and
{
i
+
1
}
not to be equal "
f
"expected tensors on rank
{
i
}
and
{
i
+
1
}
not to be equal "
f
"but they are,
{
a
}
vs
{
b
}
"
f
"but they are,
{
a
}
vs
{
b
}
"
)
)
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
optimizer
,
enable_autocast
=
False
):
model
.
train
()
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
if
criterion
:
y
=
model
(
data
)
loss
=
criterion
(
y
,
label
)
else
:
loss
=
model
(
data
,
label
)
loss
=
loss
.
float
()
if
isinstance
(
model
,
LowLevelZeroModel
):
optimizer
.
backward
(
loss
)
else
:
loss
.
backward
()
return
y
def
sync_local_from_ep
(
local_model
:
SparseMLP
,
ep_model
:
SparseMLP
,
assert_grad_flag
:
bool
=
False
)
->
None
:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for
(
local_name
,
local_param
),
(
ep_name
,
ep_param
)
in
zip
(
local_model
.
named_parameters
(),
ep_model
.
named_parameters
()
):
assert
local_name
in
ep_name
,
print
(
f
"
{
local_name
}
!=
{
ep_name
}
"
)
if
"experts"
not
in
local_name
:
if
assert_grad_flag
:
assert
torch
.
allclose
(
local_param
,
ep_param
),
f
"local_param:
{
local_param
}
, ep_param:
{
ep_param
}
"
assert
torch
.
allclose
(
local_param
.
grad
,
ep_param
.
grad
)
else
:
local_param
.
data
.
copy_
(
ep_param
.
data
)
continue
# gather param from ep model
param_list
=
[
torch
.
zeros_like
(
ep_param
)
for
_
in
range
(
get_ep_size
(
ep_param
))]
dist
.
all_gather
(
param_list
,
ep_param
,
group
=
get_ep_group
(
ep_param
))
all_param
=
torch
.
cat
(
param_list
,
dim
=
0
)
if
assert_grad_flag
:
grad_list
=
[
torch
.
zeros_like
(
ep_param
)
for
_
in
range
(
get_ep_size
(
ep_param
))]
dist
.
all_gather
(
grad_list
,
ep_param
.
grad
,
group
=
get_ep_group
(
ep_param
))
all_grad
=
torch
.
cat
(
grad_list
,
dim
=
0
)
if
assert_grad_flag
:
assert
torch
.
allclose
(
local_param
,
all_param
)
assert
torch
.
allclose
(
local_param
.
grad
,
all_grad
)
else
:
local_param
.
data
.
copy_
(
all_param
.
data
)
def
loose_close
(
a
,
b
,
dtype
:
torch
.
dtype
=
torch
.
float32
):
rtol
=
None
atol
=
None
if
dtype
is
torch
.
float16
:
rtol
=
5e-2
atol
=
5e-4
elif
dtype
is
torch
.
bfloat16
:
rtol
=
4e-3
atol
=
4e-3
a
=
a
.
detach
().
to
(
dtype
)
b
=
b
.
detach
().
to
(
dtype
).
to
(
a
.
device
)
assert_close
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
tests/test_moe/test_moe_checkpoint.py
View file @
efef43b5
...
@@ -12,7 +12,6 @@ import colossalai
...
@@ -12,7 +12,6 @@ import colossalai
from
colossalai.accelerator
import
get_accelerator
from
colossalai.accelerator
import
get_accelerator
from
colossalai.booster
import
Booster
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin.moe_hybrid_parallel_plugin
import
MoeHybridParallelPlugin
from
colossalai.booster.plugin.moe_hybrid_parallel_plugin
import
MoeHybridParallelPlugin
from
colossalai.moe.manager
import
MOE_MANAGER
from
colossalai.testing
import
DummyDataloader
,
check_state_dict_equal
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
DummyDataloader
,
check_state_dict_equal
,
rerun_if_address_is_in_use
,
spawn
sys
.
path
.
append
(
sys
.
path
.
append
(
...
@@ -95,6 +94,7 @@ def get_model(parallel):
...
@@ -95,6 +94,7 @@ def get_model(parallel):
precision
=
"bf16"
,
precision
=
"bf16"
,
tp_size
=
1
,
tp_size
=
1
,
pp_size
=
1
,
pp_size
=
1
,
ep_size
=
1
,
zero_stage
=
2
,
zero_stage
=
2
,
custom_policy
=
OpenMoeForCausalLMPolicy
(),
custom_policy
=
OpenMoeForCausalLMPolicy
(),
)
)
...
@@ -103,6 +103,7 @@ def get_model(parallel):
...
@@ -103,6 +103,7 @@ def get_model(parallel):
precision
=
"bf16"
,
precision
=
"bf16"
,
tp_size
=
1
,
tp_size
=
1
,
pp_size
=
1
,
pp_size
=
1
,
ep_size
=
dist
.
get_world_size
(),
zero_stage
=
2
,
zero_stage
=
2
,
custom_policy
=
OpenMoeForCausalLMPolicy
(),
custom_policy
=
OpenMoeForCausalLMPolicy
(),
)
)
...
@@ -111,6 +112,7 @@ def get_model(parallel):
...
@@ -111,6 +112,7 @@ def get_model(parallel):
precision
=
"bf16"
,
precision
=
"bf16"
,
tp_size
=
1
,
tp_size
=
1
,
pp_size
=
1
,
pp_size
=
1
,
ep_size
=
2
,
zero_stage
=
2
,
zero_stage
=
2
,
extra_dp_size
=
2
,
extra_dp_size
=
2
,
custom_policy
=
OpenMoeForCausalLMPolicy
(),
custom_policy
=
OpenMoeForCausalLMPolicy
(),
...
@@ -120,6 +122,7 @@ def get_model(parallel):
...
@@ -120,6 +122,7 @@ def get_model(parallel):
precision
=
"bf16"
,
precision
=
"bf16"
,
tp_size
=
1
,
tp_size
=
1
,
pp_size
=
2
,
pp_size
=
2
,
ep_size
=
2
,
zero_stage
=
1
,
zero_stage
=
1
,
microbatch_size
=
1
,
microbatch_size
=
1
,
custom_policy
=
OpenMoeForCausalLMPolicy
(),
custom_policy
=
OpenMoeForCausalLMPolicy
(),
...
@@ -130,27 +133,6 @@ def get_model(parallel):
...
@@ -130,27 +133,6 @@ def get_model(parallel):
def
_test_moe_checkpoint
(
rank
,
parallel
):
def
_test_moe_checkpoint
(
rank
,
parallel
):
if
parallel
==
None
:
MOE_MANAGER
.
setup
(
parallel
=
None
,
)
elif
parallel
==
"ep"
:
MOE_MANAGER
.
setup
(
parallel
=
"EP"
,
)
elif
parallel
==
"ep_zero"
:
MOE_MANAGER
.
setup
(
parallel
=
"EP"
,
max_ep_size
=
2
,
)
elif
parallel
==
"hybrid"
:
MOE_MANAGER
.
setup
(
parallel
=
"EP"
,
mode
=
"fixed"
,
fixed_dp_size
=
1
,
fixed_ep_size
=
2
,
fixed_pp_size
=
2
,
)
model1
,
booster1
,
optim1
=
get_model
(
parallel
)
model1
,
booster1
,
optim1
=
get_model
(
parallel
)
model2
,
booster2
,
optim2
=
get_model
(
parallel
)
model2
,
booster2
,
optim2
=
get_model
(
parallel
)
model3
,
booster3
,
optim3
=
get_model
(
parallel
)
model3
,
booster3
,
optim3
=
get_model
(
parallel
)
...
@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
...
@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
_test_moe_checkpoint
(
rank
,
parallel
)
_test_moe_checkpoint
(
rank
,
parallel
)
@
pytest
.
mark
.
skip
(
reason
=
"This is tested in ColossalMOE"
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"parallel"
,
[
None
,
"ep"
,
"ep_zero"
,
"hybrid"
])
@
pytest
.
mark
.
parametrize
(
"parallel"
,
[
None
,
"ep"
,
"ep_zero"
,
"hybrid"
])
...
...
tests/test_moe/test_moe_router.py
View file @
efef43b5
...
@@ -4,15 +4,21 @@ import torch
...
@@ -4,15 +4,21 @@ import torch
from
colossalai.moe.routers
import
MoeRouter
,
Top1Router
,
Top2Router
,
TopKRouter
from
colossalai.moe.routers
import
MoeRouter
,
Top1Router
,
Top2Router
,
TopKRouter
@
pytest
.
mark
.
parametrize
([
"router"
,
"num_groups"
],
[
@
pytest
.
mark
.
parametrize
(
(
Top1Router
(),
1
),
[
"router"
,
"num_groups"
],
(
Top2Router
(),
1
),
[
# (TopKRouter(num_selected_experts=3), 4),
(
Top1Router
(),
1
),
])
(
Top2Router
(),
1
),
@
pytest
.
mark
.
parametrize
([
"batch_size"
,
"seq_len"
,
"num_experts"
],
[
# (TopKRouter(num_selected_experts=3), 4),
(
4
,
5
,
8
),
],
(
3
,
4
,
4
),
)
])
@
pytest
.
mark
.
parametrize
(
[
"batch_size"
,
"seq_len"
,
"num_experts"
],
[
(
4
,
5
,
8
),
(
3
,
4
,
4
),
],
)
def
test_router_forward
(
router
:
MoeRouter
,
batch_size
:
int
,
seq_len
:
int
,
num_experts
:
int
,
num_groups
:
int
):
def
test_router_forward
(
router
:
MoeRouter
,
batch_size
:
int
,
seq_len
:
int
,
num_experts
:
int
,
num_groups
:
int
):
x
=
torch
.
randn
((
batch_size
*
seq_len
,
num_experts
)).
cuda
()
x
=
torch
.
randn
((
batch_size
*
seq_len
,
num_experts
)).
cuda
()
if
num_groups
>
1
:
if
num_groups
>
1
:
...
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
...
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
router
.
train
()
router
.
train
()
if
isinstance
(
router
,
TopKRouter
):
if
isinstance
(
router
,
TopKRouter
):
_
,
combine_array
,
dispatch_mask
=
router
(
x
,
expert_capacity
=
2
)
combine_array
,
dispatch_mask
=
router
(
x
,
expert_capacity
=
2
)
else
:
else
:
_
,
combine_array
,
dispatch_mask
=
router
(
x
)
combine_array
,
dispatch_mask
=
router
(
x
)
[
1
:
3
]
assert
combine_array
.
shape
[:
-
1
]
==
x
.
shape
assert
combine_array
.
shape
[:
-
1
]
==
x
.
shape
assert
dispatch_mask
.
shape
[:
-
1
]
==
x
.
shape
assert
dispatch_mask
.
shape
[:
-
1
]
==
x
.
shape
assert
torch
.
all
(
dispatch_mask
.
sum
(
-
1
).
sum
(
-
1
)
<=
router
.
k_value
)
assert
torch
.
all
(
dispatch_mask
.
sum
(
-
1
).
sum
(
-
1
)
<=
router
.
k_value
)
router
.
eval
()
router
.
eval
()
if
isinstance
(
router
,
TopKRouter
):
if
isinstance
(
router
,
TopKRouter
):
_
,
combine_array
,
dispatch_mask
=
router
(
x
,
expert_capacity
=
2
)
combine_array
,
dispatch_mask
=
router
(
x
,
expert_capacity
=
2
)
else
:
else
:
_
,
combine_array
,
dispatch_mask
=
router
(
x
)
combine_array
,
dispatch_mask
=
router
(
x
)
[
1
:
3
]
assert
combine_array
.
shape
[:
-
1
]
==
x
.
shape
assert
combine_array
.
shape
[:
-
1
]
==
x
.
shape
assert
dispatch_mask
.
shape
[:
-
1
]
==
x
.
shape
assert
dispatch_mask
.
shape
[:
-
1
]
==
x
.
shape
assert
torch
.
all
(
dispatch_mask
.
sum
(
-
1
).
sum
(
-
1
)
<=
router
.
k_value
)
assert
torch
.
all
(
dispatch_mask
.
sum
(
-
1
).
sum
(
-
1
)
<=
router
.
k_value
)
...
...
tests/test_moe/test_moe_zero_fwd_bwd.py
View file @
efef43b5
...
@@ -4,102 +4,75 @@ import torch
...
@@ -4,102 +4,75 @@ import torch
import
colossalai
import
colossalai
from
colossalai.booster
import
Booster
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin
import
LowLevelZeroPlugin
from
colossalai.booster.plugin
import
LowLevelZeroPlugin
from
colossalai.booster.plugin.low_level_zero_plugin
import
LowLevelZeroModel
from
colossalai.moe.manager
import
MOE_MANAGER
from
colossalai.moe.manager
import
MOE_MANAGER
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing.random
import
seed_all
from
colossalai.testing.random
import
seed_all
from
tests.test_moe.moe_utils
import
Moe
GradientHandler
,
MoeModel
from
tests.test_moe.moe_utils
import
Moe
Model
,
delete_moe_info
,
run_fwd_bwd
,
sync_local_from_ep
def
split_ddp_grad
(
grad
,
world_size
):
def
run_zero_test
(
local_rank
,
stage
=
1
):
with
torch
.
no_grad
():
grad
=
grad
.
clone
().
detach
().
flatten
()
padding_size
=
(
world_size
-
grad
.
numel
()
%
world_size
)
%
world_size
if
padding_size
>
0
:
grad
=
torch
.
nn
.
functional
.
pad
(
grad
,
[
0
,
padding_size
])
splited_grad
=
grad
.
split
(
grad
.
numel
()
//
world_size
)
return
splited_grad
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
optimizer
,
enable_autocast
=
False
):
model
.
train
()
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
if
criterion
:
y
=
model
(
data
)
loss
=
criterion
(
y
,
label
)
else
:
loss
=
model
(
data
,
label
)
loss
=
loss
.
float
()
if
isinstance
(
model
,
LowLevelZeroModel
):
optimizer
.
backward
(
loss
)
else
:
loss
.
backward
()
return
y
def
run_zero_test
(
local_rank
,
world_size
,
stage
=
1
):
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
zero_model
=
MoeModel
()
MOE_MANAGER
.
__init__
()
optimizer
=
torch
.
optim
.
Adam
(
zero_model
.
parameters
())
MOE_MANAGER
.
setup
(
parallel
=
"EP"
)
plugin
=
LowLevelZeroPlugin
(
stage
=
stage
,
precision
=
"fp32"
)
moe_model
=
MoeModel
().
bfloat16
()
booster
=
Booster
(
plugin
=
plugin
)
moe_optimizer
=
torch
.
optim
.
Adam
(
moe_model
.
parameters
())
zero_model
,
optimizer
,
_
,
_
,
_
=
booster
.
boost
(
zero_model
,
optimizer
)
moe_plugin
=
LowLevelZeroPlugin
(
stage
=
stage
,
precision
=
"bf16"
)
moe_booster
=
Booster
(
plugin
=
moe_plugin
)
torch_model
=
MoeModel
()
moe_model
,
moe_optimizer
,
_
,
_
,
_
=
moe_booster
.
boost
(
moe_model
,
moe_optimizer
)
for
zero_param
,
torch_param
in
zip
(
zero_model
.
parameters
(),
torch_model
.
parameters
()):
torch_param
.
data
.
copy_
(
zero_param
.
data
)
MOE_MANAGER
.
__init__
()
torch_model
=
torch_model
.
cuda
()
MOE_MANAGER
.
setup
(
parallel
=
None
)
grad_handler
=
MoeGradientHandler
(
torch_model
)
zero_model
=
MoeModel
().
bfloat16
()
delete_moe_info
(
zero_model
)
# assert zero model
zero_optimizer
=
torch
.
optim
.
Adam
(
zero_model
.
parameters
())
for
(
torch_name
,
torch_param
),
(
zero_name
,
zero_param
)
in
zip
(
zero_plugin
=
LowLevelZeroPlugin
(
stage
=
stage
,
precision
=
"bf16"
)
torch_model
.
named_parameters
(),
zero_model
.
module
.
named_parameters
()
zero_booster
=
Booster
(
plugin
=
zero_plugin
)
):
zero_model
,
zero_optimizer
,
_
,
_
,
_
=
zero_booster
.
boost
(
zero_model
,
zero_optimizer
)
assert
zero_name
==
torch_name
sync_local_from_ep
(
zero_model
,
moe_model
)
assert
torch
.
allclose
(
zero_param
.
data
,
torch_param
.
data
)
data
=
torch
.
randn
(
16
,
4
).
bfloat16
().
cuda
()
data
=
torch
.
randn
(
16
,
4
).
cuda
()
label
=
torch
.
randint
(
0
,
4
,
(
16
,)).
cuda
()
label
=
torch
.
randint
(
0
,
4
,
(
16
,)).
cuda
()
torch_out
=
run_fwd_bwd
(
torch_model
,
data
,
label
,
criterion
,
None
)
zero_out
=
run_fwd_bwd
(
zero_model
,
data
,
label
,
criterion
,
zero_optimizer
)
zero_out
=
run_fwd_bwd
(
zero_model
,
data
,
label
,
criterion
,
optimizer
)
moe_out
=
run_fwd_bwd
(
moe_model
,
data
,
label
,
criterion
,
moe_optimizer
)
assert
torch
.
allclose
(
torch_out
,
zero_out
)
assert
torch
.
allclose
(
zero_out
,
moe_out
)
grad_handler
.
handle_gradient
()
for
(
zero
_name
,
zero
_param
),
(
torch
_name
,
torch
_param
)
in
zip
(
for
(
moe
_name
,
moe
_param
),
(
zero
_name
,
zero
_param
)
in
zip
(
zero
_model
.
module
.
named_parameters
(),
torch
_model
.
named_parameters
()
moe
_model
.
module
.
named_parameters
(),
zero
_model
.
module
.
named_parameters
()
):
):
assert
zero_name
==
torch_name
assert
moe_name
==
zero_name
zero_grad_list
=
optimizer
.
_grad_store
.
get_partitioned_gradients_by_param_id
(
0
,
id
(
zero_param
))
moe_grad_list
=
moe_optimizer
.
_grad_store
.
get_partitioned_gradients_by_param_id
(
0
,
id
(
moe_param
))
if
hasattr
(
zero_param
,
"moe_info"
):
zero_grad_list
=
zero_optimizer
.
_grad_store
.
get_partitioned_gradients_by_param_id
(
0
,
id
(
zero_param
))
assert
len
(
zero_grad_list
)
==
0
if
hasattr
(
moe_param
,
"moe_info"
):
assert
torch
.
allclose
(
zero_param
.
grad
,
torch_param
.
grad
)
assert
len
(
moe_grad_list
)
==
0
if
stage
==
1
:
zero_grad
=
zero_grad_list
[
local_rank
].
view
(
moe_param
.
grad
.
shape
)
else
:
zero_grad
=
zero_grad_list
[
0
].
view
(
moe_param
.
grad
.
shape
)
assert
torch
.
allclose
(
moe_param
.
grad
,
zero_grad
,
atol
=
1e-5
),
f
"zero grad:
\n
{
moe_param
.
grad
}
\n
torch grad:
\n
{
zero_grad
}
\n
max diff:
{
(
moe_param
.
grad
-
zero_grad
).
abs
().
max
()
}
, mean diff:
{
(
moe_param
.
grad
-
zero_grad
).
abs
().
mean
()
}
"
else
:
else
:
assert
len
(
zero_grad_list
)
>
0
assert
len
(
moe_grad_list
)
>
0
torch_grad_list
=
split_ddp_grad
(
torch_param
.
grad
,
world_size
)
assert
len
(
moe_grad_list
)
==
len
(
zero_grad_list
)
if
stage
==
2
:
for
moe_grad
,
zero_grad
in
zip
(
moe_grad_list
,
zero_grad_list
):
torch_grad_list
=
torch_grad_list
[
local_rank
:
local_rank
+
1
]
assert
torch
.
allclose
(
moe_grad
,
zero_grad
)
assert
len
(
zero_grad_list
)
==
len
(
torch_grad_list
)
for
zero_grad
,
torch_grad
in
zip
(
zero_grad_list
,
torch_grad_list
):
assert
torch
.
allclose
(
zero_grad
,
torch_grad
)
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
,
stage
):
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
MOE_MANAGER
.
setup
(
parallel
=
"EP"
)
seed_all
(
42
+
rank
)
seed_all
(
42
+
rank
)
run_zero_test
(
rank
,
world_size
,
stage
=
1
)
run_zero_test
(
rank
,
stage
=
stage
)
run_zero_test
(
rank
,
world_size
,
stage
=
2
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
[
1
,
2
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_moe_zero_model
(
world_size
):
def
test_moe_zero_model
(
world_size
,
stage
):
spawn
(
run_dist
,
world_size
)
spawn
(
run_dist
,
world_size
,
stage
=
stage
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_moe_zero_model
(
world_size
=
2
)
test_moe_zero_model
(
world_size
=
2
,
stage
=
1
)
tests/test_moe/test_moe_zero_optim.py
View file @
efef43b5
...
@@ -4,89 +4,80 @@ import torch
...
@@ -4,89 +4,80 @@ import torch
import
colossalai
import
colossalai
from
colossalai.booster
import
Booster
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin
import
LowLevelZeroPlugin
from
colossalai.booster.plugin
import
LowLevelZeroPlugin
from
colossalai.booster.plugin.low_level_zero_plugin
import
LowLevelZeroModel
from
colossalai.moe.manager
import
MOE_MANAGER
from
colossalai.moe.manager
import
MOE_MANAGER
from
colossalai.tensor.moe_tensor.api
import
is_moe_tensor
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
tests.test_moe.moe_utils
import
MoeGradientHandler
,
MoeModel
from
colossalai.testing.random
import
seed_all
from
tests.test_moe.moe_utils
import
MoeModel
,
delete_moe_info
,
loose_close
,
run_fwd_bwd
,
sync_local_from_ep
def
split_ddp_grad
(
grad
,
world_size
):
def
run_zero_test
(
local_rank
,
stage
=
1
):
with
torch
.
no_grad
():
grad
=
grad
.
clone
().
detach
().
flatten
()
padding_size
=
(
world_size
-
grad
.
numel
()
%
world_size
)
%
world_size
if
padding_size
>
0
:
grad
=
torch
.
nn
.
functional
.
pad
(
grad
,
[
0
,
padding_size
])
splited_grad
=
grad
.
split
(
grad
.
numel
()
//
world_size
)
return
splited_grad
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
optimizer
,
enable_autocast
=
False
):
model
.
train
()
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
if
criterion
:
y
=
model
(
data
)
loss
=
criterion
(
y
,
label
)
else
:
loss
=
model
(
data
,
label
)
loss
=
loss
.
float
()
if
isinstance
(
model
,
LowLevelZeroModel
):
optimizer
.
backward
(
loss
)
else
:
loss
.
backward
()
return
y
def
run_zero_optim_test
(
local_rank
,
world_size
,
stage
=
1
):
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
zero_model
=
MoeModel
()
MOE_MANAGER
.
__init__
()
zero_optimizer
=
torch
.
optim
.
Adam
(
zero_model
.
parameters
())
MOE_MANAGER
.
setup
(
parallel
=
"EP"
)
plugin
=
LowLevelZeroPlugin
(
stage
=
stage
,
precision
=
"fp32"
)
moe_model
=
MoeModel
().
bfloat16
()
booster
=
Booster
(
plugin
=
plugin
)
moe_optimizer
=
torch
.
optim
.
Adam
(
moe_model
.
parameters
(),
lr
=
1.0
)
zero_model
,
zero_optimizer
,
_
,
_
,
_
=
booster
.
boost
(
zero_model
,
zero_optimizer
)
moe_plugin
=
LowLevelZeroPlugin
(
stage
=
stage
,
precision
=
"bf16"
)
moe_booster
=
Booster
(
plugin
=
moe_plugin
)
torch_model
=
MoeModel
()
moe_model
,
moe_optimizer
,
_
,
_
,
_
=
moe_booster
.
boost
(
moe_model
,
moe_optimizer
)
for
zero_param
,
torch_param
in
zip
(
zero_model
.
parameters
(),
torch_model
.
parameters
()):
torch_param
.
data
.
copy_
(
zero_param
.
data
)
MOE_MANAGER
.
__init__
()
torch_optimizer
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
())
MOE_MANAGER
.
setup
(
parallel
=
None
)
torch_model
=
torch_model
.
cuda
()
zero_model
=
MoeModel
().
bfloat16
()
grad_handler
=
MoeGradientHandler
(
torch_model
)
delete_moe_info
(
zero_model
)
sync_local_from_ep
(
zero_model
,
moe_model
)
for
_
in
range
(
2
):
zero_optimizer
=
torch
.
optim
.
Adam
(
zero_model
.
parameters
(),
lr
=
1.0
)
data
=
torch
.
randn
(
16
,
4
).
cuda
()
/
(
local_rank
+
1
)
zero_plugin
=
LowLevelZeroPlugin
(
stage
=
stage
,
precision
=
"bf16"
)
label
=
torch
.
randint
(
0
,
4
,
(
16
,)).
cuda
()
zero_booster
=
Booster
(
plugin
=
zero_plugin
)
run_fwd_bwd
(
torch_model
,
data
,
label
,
criterion
,
None
)
zero_model
,
zero_optimizer
,
_
,
_
,
_
=
zero_booster
.
boost
(
zero_model
,
zero_optimizer
)
run_fwd_bwd
(
zero_model
,
data
,
label
,
criterion
,
zero_optimizer
)
grad_handler
.
handle_gradient
()
for
(
moe_name
,
moe_param
),
(
zero_name
,
zero_param
)
in
zip
(
moe_model
.
named_parameters
(),
zero_model
.
named_parameters
()
torch_optimizer
.
step
()
):
if
".experts."
in
moe_name
:
continue
assert
moe_name
==
zero_name
assert
torch
.
allclose
(
moe_param
.
data
,
zero_param
.
data
),
f
"
{
moe_name
}
\n
torch_param
{
moe_param
.
data
}
\n
zero_param
{
zero_param
.
data
}
"
for
_
in
range
(
1
):
data
=
torch
.
randn
(
2
,
4
).
bfloat16
().
cuda
()
label
=
torch
.
randint
(
0
,
4
,
(
2
,)).
cuda
()
moe_out
=
run_fwd_bwd
(
moe_model
,
data
,
label
,
criterion
,
moe_optimizer
)
zero_out
=
run_fwd_bwd
(
zero_model
,
data
,
label
,
criterion
,
zero_optimizer
)
assert
torch
.
allclose
(
zero_out
,
moe_out
)
moe_optimizer
.
step
()
zero_optimizer
.
step
()
zero_optimizer
.
step
()
for
(
torch
_name
,
torch
_param
),
(
zero_name
,
zero_param
)
in
zip
(
for
(
moe
_name
,
moe
_param
),
(
zero_name
,
zero_param
)
in
zip
(
torch
_model
.
named_parameters
(),
zero_model
.
named_parameters
()
moe
_model
.
named_parameters
(),
zero_model
.
named_parameters
()
):
):
assert
torch
.
allclose
(
assert
moe_name
==
zero_name
torch_param
.
data
,
zero_param
.
data
if
is_moe_tensor
(
moe_param
):
),
f
"
{
torch_name
}
\n
torch_param
{
torch_param
.
data
}
\n
zero_param
{
zero_param
.
data
}
"
param_size
=
moe_param
.
shape
[
0
]
zero_param
=
zero_param
[
local_rank
*
param_size
:
(
local_rank
+
1
)
*
param_size
]
loose_close
(
moe_param
.
data
,
zero_param
.
data
,
dtype
=
moe_param
.
dtype
)
torch
_optimizer
.
zero_grad
()
moe
_optimizer
.
zero_grad
()
zero_optimizer
.
zero_grad
()
zero_optimizer
.
zero_grad
()
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
,
stage
):
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
MOE_MANAGER
.
setup
(
parallel
=
"EP"
)
seed_all
(
42
+
rank
)
run_zero_optim_test
(
rank
,
world_size
,
stage
=
1
)
run_zero_test
(
rank
,
stage
=
stage
)
run_zero_optim_test
(
rank
,
world_size
,
stage
=
2
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
[
1
,
2
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_moe_zero_optim
(
world_size
):
def
test_moe_zero_optim
(
world_size
,
stage
):
spawn
(
run_dist
,
world_size
)
spawn
(
run_dist
,
world_size
,
stage
=
stage
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_moe_zero_optim
(
world_size
=
2
)
test_moe_zero_optim
(
world_size
=
2
,
stage
=
1
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment