OpenDAS / ColossalAI · Commits

Commit efef43b5 (unverified), authored Feb 08, 2024 by Frank Lee, committed by GitHub on Feb 08, 2024

Merge pull request #5372 from hpcaitech/exp/mixtral

Parents: 4c03347f, 06db94fb

Showing 20 changed files with 2091 additions and 6 deletions (+2091, -6)
applications/ColossalMoE/README.md (+0, -0)
applications/ColossalMoE/colossal_moe/__init__.py (+0, -0)
applications/ColossalMoE/colossal_moe/models/__init__.py (+0, -0)
applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py (+629, -0)
applications/ColossalMoE/colossal_moe/models/mixtral_layer.py (+92, -0)
applications/ColossalMoE/colossal_moe/models/mixtral_policy.py (+557, -0)
applications/ColossalMoE/colossal_moe/utils.py (+84, -0)
applications/ColossalMoE/infer.py (+111, -0)
applications/ColossalMoE/infer.sh (+7, -0)
applications/ColossalMoE/requirements.txt (+5, -0)
applications/ColossalMoE/setup.py (+43, -0)
applications/ColossalMoE/tests/__init__.py (+0, -0)
applications/ColossalMoE/tests/test_mixtral_layer.py (+63, -0)
applications/ColossalMoE/tests/test_moe_checkpoint.py (+146, -0)
applications/ColossalMoE/train.py (+295, -0)
applications/ColossalMoE/train.sh (+19, -0)
applications/ColossalMoE/version.txt (+1, -0)
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (+25, -3)
colossalai/checkpoint_io/checkpoint_io_base.py (+10, -2)
colossalai/moe/__init__.py (+4, -1)
applications/ColossalMoE/README.md (new file, mode 100644)

File added.
applications/ColossalMoE/colossal_moe/__init__.py (new file, mode 100644)
applications/ColossalMoE/colossal_moe/models/__init__.py (new file, mode 100644)
applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py (new file, mode 100644)

import copy
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import (
    StateDictSharder,
    gather_distributed_param,
    get_model_base_filenames,
    get_optimizer_base_filenames,
    load_shard_state_dict,
    load_states_into_optimizer,
    save_config_file,
    save_param_groups,
    save_state_dict_shards,
    search_tp_partition_dim,
    sharded_optimizer_loading_epilogue,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"


class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
    def __init__(
        self,
        dp_group: ProcessGroup,
        pp_group: ProcessGroup,
        tp_group: ProcessGroup,
        zero_stage: int,
        verbose: bool = True,
    ) -> None:
        super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
        moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
        self.ep_group = moe_info.ep_group
        self.ep_size = moe_info.ep_size
        self.ep_rank = moe_info.ep_rank
        self.real_dp_rank = moe_info.dp_rank

    @staticmethod
    def _model_sharder(
        model: nn.Module,
        prefix: str = "",
        keep_vars: bool = False,
        size_per_shard: int = 1024,
        param_name_pattern: Optional[str] = None,
    ) -> Iterator[Tuple[OrderedDict, int]]:
        # An internal method that breaks state_dict of model into shards within limited size.
        state_dict_sharder = StateDictSharder(size_per_shard)

        # Save parameters.
        for name, param in model.named_parameters():
            if param is None:
                continue
            if param_name_pattern is not None and param_name_pattern not in name:
                continue
            # Gather tensor pieces when using tensor parallel.
            param_ = gather_distributed_param(param, keep_vars=False)
            block, block_size = state_dict_sharder.append_param(prefix + name, param_)
            if block is not None:
                yield block, block_size

        # Save buffers.
        for name, buf in model.named_buffers():
            if buf is not None and name not in model._non_persistent_buffers_set:
                buffer = buf if keep_vars else buf.detach()
                block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
                if block is not None:
                    yield block, block_size

        # Save extra states.
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if (
            getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
            is not torch.nn.Module.get_extra_state
        ):
            extra_state = model.get_extra_state()
            block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
            if block is not None:
                yield block, block_size

        # Return the last block in sharder.
        yield state_dict_sharder.current_block, state_dict_sharder.current_block_size

    def save_sharded_model(
        self,
        model: ModelWrapper,
        checkpoint: str,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
    ) -> None:
        """
        Save sharded model checkpoint under the given checkpointing path.
        The following files will be created under the path:
        - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
        - Multiple files that store state tensors of models.
          If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
          If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"

        Args:
            model (nn.Module): Model on local device to be saved.
            checkpoint (str): Checkpointing path which should be a directory path.
            gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
            prefix (str, optional): Prefix of file to save. Defaults to None.
            size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
            use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
        """
        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        model = model.unwrap()

        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return

        Path(checkpoint).mkdir(parents=True, exist_ok=True)

        if self.real_dp_rank != 0:
            dist.barrier()
            return

        # ep_rank 0 saves all the parameters and buffers.
        # other ep_ranks save only experts
        ep_param_pattern = "experts." if self.ep_rank != 0 else None

        # Then collect the sharded parameters & buffers along tp_group.
        # Only devices with tp_rank == 0 are responsible for model saving.
        state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
            model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
        )
        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
        index_file = CheckpointIndexFile(checkpoint)
        control_saving = self.tp_rank == 0

        if self.pp_size == 1 and self.ep_size == 1:
            # When pipeline is not used, save the model shards as in general checkpointIO
            total_size = save_state_dict_shards(
                sharded_state_dict=state_dict_shard,
                checkpoint=checkpoint,
                index_file=index_file,
                base_filename=weights_name,
                is_master=control_saving,
                use_safetensors=use_safetensors,
            )
            if control_saving:
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
                save_config_file(model, checkpoint)
                if self.verbose and self.coordinator.is_master():
                    logging.info(
                        f"The model is split into checkpoint shards. "
                        f"You can find where each parameter has been saved in the "
                        f"index located at {save_index_file}."
                    )

            dist.barrier()
        else:
            # When pipeline is used, each stage produces its own shard files and index files.
            # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
            # After all the state_dicts have been saved, the master rank integrates all the index files
            # into one final index file and deletes the tmp folder.

            final_index_file_path = copy.deepcopy(save_index_file)
            tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
            Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

            # Manage filenames of sharded weights and index file for each pipeline stage.
            weights_name = weights_name.replace(
                ".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin"
            )
            weights_name = weights_name.replace(
                ".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
            )
            save_index_file = save_index_file.replace(
                ".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json"
            )
            save_index_file = os.path.join("tmp_index_files", save_index_file)

            total_size = save_state_dict_shards(
                sharded_state_dict=state_dict_shard,
                checkpoint=checkpoint,
                index_file=index_file,
                base_filename=weights_name,
                is_master=control_saving,
                use_safetensors=use_safetensors,
                use_pp_format=True,
            )
            if control_saving:
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
            else:
                dist.barrier()
                return

            dist.barrier()

            # The global master rank integrates the index files and clean the folder.
            if self.coordinator.is_master():
                final_index_file = CheckpointIndexFile(checkpoint)
                final_index_file.append_meta_data("total_size", 0)

                for filename in os.listdir(tmp_index_file_folder):
                    stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
                    final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
                    for weight, weight_filename in stage_index_file.weight_map.items():
                        final_index_file.append_weight_map(weight, weight_filename)

                final_index_file.write_index_file(final_index_file_path)
                save_config_file(model, checkpoint)
                rmtree(tmp_index_file_folder)
                if self.verbose and self.coordinator.is_master():
                    logging.info(
                        f"The model is split into checkpoint shards. "
                        f"You can find where each parameter has been saved in the "
                        f"index located at {final_index_file_path}."
                    )

    @staticmethod
    def gather_from_sharded_optimizer_state(
        state: OrderedDict,
        param: torch.Tensor,
        original_shape: torch.Size,
        dp_group: ProcessGroup,
        tp_group: ProcessGroup,
        use_zero: bool,
        inplace: bool,
        is_moe_param: bool,
        device: torch.device = torch.device("cpu"),
    ) -> OrderedDict:
        """
        With given parameter and its optimizer states, gather the complete optimizer state for saving.

        Args:
            state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
            param (torch.Tensor): The given parameter. It should be working_param when using Zero.
            original_shape (torch.Size): The size of parameter before sharding.
            dp_group (ProcessGroup): The process group of data parallel.
            tp_group (ProcessGroup): The process group of tensor parallel.
            use_zero (bool): Whether Zero is used.
            inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
            device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').

        Returns:
            OrderedDict: The complete optimizer state of given parameter.
        """
        dp_size = dist.get_world_size(dp_group)
        tp_size = dist.get_world_size(tp_group)
        current_shape = param.shape
        state_ = state if inplace else copy.deepcopy(state)

        for k, v in state_.items():
            if isinstance(v, torch.Tensor) and k != "step":
                # First gather Zero shards.
                if use_zero and not is_moe_param:
                    v = v.cuda()
                    gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
                    dist.all_gather(gather_tensor, v, group=dp_group)
                    v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)

                # Then gather TP shards.
                partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
                if partition_dim is not None:
                    gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
                    dist.all_gather(gather_tensor, v, group=tp_group)
                    v = torch.cat(gather_tensor, dim=partition_dim)

                state_[k] = v.detach().clone().to(device)

        return state_

    @staticmethod
    def _optimizer_sharder(
        optimizer: OptimizerWrapper,
        use_zero: bool,
        dp_group: ProcessGroup,
        tp_group: ProcessGroup,
        size_per_shard: int = 1024,
        only_moe_param: bool = False,
    ):
        # An internal method that breaks state_dict of optimizer into shards within limited size.
        state_dict_sharder = StateDictSharder(size_per_shard)
        param_info = optimizer.param_info
        master_to_working_map = optimizer.get_master_to_working_map()

        for param, state in optimizer.optim.state.items():
            if param is None:
                continue

            if master_to_working_map is not None:
                working_param = master_to_working_map[id(param)]
            else:
                working_param = param

            param_id = param_info["param2id"][id(working_param)]
            original_shape = param_info["param2shape"][id(working_param)]
            state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
                state,
                working_param,
                original_shape=original_shape,
                dp_group=dp_group,
                tp_group=tp_group,
                use_zero=use_zero,
                inplace=False,
                is_moe_param=is_moe_tensor(working_param),
            )

            if only_moe_param and not is_moe_tensor(working_param):
                continue
            block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
            if block is not None:
                yield block, block_size

        # Return the last block in sharder.
        yield state_dict_sharder.current_block, state_dict_sharder.current_block_size

    def save_sharded_optimizer(
        self,
        optimizer: OptimizerWrapper,
        checkpoint: str,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
    ):
        """
        Save sharded optimizer checkpoint under the given checkpointing path.
        The following files will be created under the path:
        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
        - A group file (pytorch_optim_group.bin) recording information of param_groups
        - Multiple files that store state tensors of optimizers.
          If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
          If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"

        Args:
            optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
            checkpoint (str): Path to save optimizer state_dict
            gather_dtensor (bool): Whether to gather_dtensor, not used
            prefix (str): Prefix of file to save
            size_per_shard (int): Max file size of each file shard that store state tensors
        """
        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"

        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return

        Path(checkpoint).mkdir(parents=True, exist_ok=True)

        # Devices along the same dp_group share the same copies of states when zero is not used.
        # In this case only let the device with dp_rank == 0 save the model.
        if not self.use_zero and self.real_dp_rank != 0:
            dist.barrier()
            return

        # Then collect the sharded states along dp_group(if using zero)/tp_group.
        # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
        state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
            optimizer,
            use_zero=self.use_zero,
            dp_group=self.dp_group,
            tp_group=self.tp_group,
            size_per_shard=size_per_shard,
            only_moe_param=self.ep_rank != 0,
        )
        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
        index_file = CheckpointIndexFile(checkpoint)
        control_saving = self.real_dp_rank == 0 and self.tp_rank == 0

        if self.pp_size == 1 and self.ep_size == 1:
            # When pipeline is not used, save the optimizer shards as in general checkpointIO
            total_size = save_state_dict_shards(
                sharded_state_dict=state_dict_shard,
                checkpoint=checkpoint,
                index_file=index_file,
                base_filename=states_name,
                is_master=control_saving,
            )

            if control_saving:
                # Store param groups.
                index_file.append_meta_data("param_groups", param_group_file)
                group_file_path = os.path.join(checkpoint, param_group_file)
                param_groups = [
                    {**group, "params": group_info["params"]}
                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
                ]
                save_param_groups({"param_groups": param_groups}, group_file_path)
                # Store index file.
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
                if self.verbose and self.coordinator.is_master():
                    logging.info(
                        f"The optimizer is going to be split to checkpoint shards. "
                        f"You can find where each parameter has been saved in the "
                        f"index located at {save_index_file}."
                    )

            dist.barrier()
        else:
            # When pipeline is used, each stage produces its own shard files and index files.
            # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
            # After all the state_dicts have been saved, the master rank integrates all the index files
            # into one final index file and deletes the tmp folder.

            final_index_file_path = copy.deepcopy(save_index_file)
            tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
            Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

            # Manage filenames of sharded weights and index file for each pipeline stage.
            states_name = states_name.replace(
                ".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin"
            )
            save_index_file = save_index_file.replace(
                ".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json"
            )
            save_index_file = os.path.join("tmp_index_files", save_index_file)

            total_size = save_state_dict_shards(
                sharded_state_dict=state_dict_shard,
                checkpoint=checkpoint,
                index_file=index_file,
                base_filename=states_name,
                is_master=control_saving,
                use_pp_format=True,
            )

            if control_saving:
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
            else:
                dist.barrier()
                return

            dist.barrier()

            # The global master rank integrates the index files and clean the folder.
            if self.coordinator.is_master():
                final_index_file = CheckpointIndexFile(checkpoint)
                final_index_file.append_meta_data("total_size", 0)

                for filename in os.listdir(tmp_index_file_folder):
                    stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
                    final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
                    for param_id, state_filename in stage_index_file.weight_map.items():
                        final_index_file.append_weight_map(param_id, state_filename)

                # Store param groups.
                final_index_file.append_meta_data("param_groups", param_group_file)
                group_file_path = os.path.join(checkpoint, param_group_file)
                param_groups = [
                    {**group, "params": group_info["params"]}
                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
                ]
                save_param_groups({"param_groups": param_groups}, group_file_path)

                final_index_file.write_index_file(final_index_file_path)
                rmtree(tmp_index_file_folder)

                if self.verbose and self.coordinator.is_master():
                    logging.info(
                        f"The optimizer is split into checkpoint shards. "
                        f"You can find where each parameter has been saved in the "
                        f"index located at {final_index_file_path}."
                    )

    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
        """
        Load sharded optimizer with the given path to index file of checkpoint folder.

        Args:
            optimizer (OptimizerWrapper): The optimizer to be loaded.
            checkpoint_index_file (str): Path to the index file of checkpointing folder.
            prefix (str): Not used.
        """
        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

        def _get_param_id_from_optimizer_param(
            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
        ):
            if master_to_working_map is not None:
                working_param = master_to_working_map[id(param)]
            else:
                working_param = param
            return optimizer.param_info["param2id"][id(working_param)]

        # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
        # When Zero is used, the mapped parameter objects should be fp32 master parameters.
        # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
        id_map = {}
        master_to_working_map = optimizer.get_master_to_working_map()
        for pg in optimizer.optim.param_groups:
            for param in pg["params"]:
                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                id_map[param_id] = param

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        weight_map = {int(k): v for k, v in weight_map.items()}  # convert saved id from str to int

        # Load param_groups
        param_group_path = ckpt_index_file.get_param_group_filename()
        if param_group_path is None:
            raise RuntimeError(
                f"Invalid index file path {checkpoint_index_file} for an optimizer. \
                Lacking param group file under current directory."
            )
        saved_groups = torch.load(param_group_path)

        updated_groups = []
        for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
            # obtain updated param group
            new_pg = copy.deepcopy(saved_pg)
            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
            updated_groups.append(new_pg)
        # ep param groups
        if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
            new_pg = copy.deepcopy(saved_pg)
            new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
            updated_groups.append(new_pg)
        optimizer.optim.__dict__.update({"param_groups": updated_groups})

        # Load saved states to optimizer.
        # Keep a record of loaded files so that file will not be repeatedly loaded.
        loaded_file = set()
        for pg in optimizer.optim.param_groups:
            for param in pg["params"]:
                if param is None:
                    continue
                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                if param_id not in weight_map:
                    continue
                filename = weight_map[param_id]

                # If this param's states has been loaded before, directly return.
                if filename in loaded_file:
                    continue

                file_path = os.path.join(ckpt_root_path, filename)
                state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
                load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
                loaded_file.add(filename)

        # Then shard the loaded optimizer states if using tp/zero.
        for param, state in optimizer.optim.state.items():
            device = param.device
            if master_to_working_map is not None:
                working_param = master_to_working_map[id(param)]
            else:
                working_param = param
            original_shape = optimizer.param_info["param2shape"][id(working_param)]
            sharded_state = self.shard_from_complete_optimizer_state(
                state,
                current_shape=working_param.shape,
                original_shape=original_shape,
                device=device,
                inplace=True,
                is_moe_param=is_moe_tensor(working_param),
            )
            optimizer.optim.state[param] = sharded_state

        sharded_optimizer_loading_epilogue(optimizer.optim)
        if self.verbose and self.coordinator.is_master():
            logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

    def shard_from_complete_optimizer_state(
        self,
        state: OrderedDict,
        current_shape: torch.Size,
        original_shape: torch.Size,
        device: torch.device,
        inplace: bool,
        is_moe_param: bool,
    ) -> OrderedDict:
        """
        With complete optimizer states of a specific parameter loaded from checkpoint,
        slice out the sharded optimizer states kept by current device.

        Args:
            state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
            current_shape (torch.Size): The size of parameter after sharding.
            original_shape (torch.Size): The size of parameter before sharding.
            device (torch.device): The destination device of loaded optimizer states.
            inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.

        Returns:
            OrderedDict: The sharded optimizer state of the given parameter.
        """
        state_ = state if inplace else copy.deepcopy(state)

        for k, v in state_.items():
            if isinstance(v, torch.Tensor) and k != "step":
                # Shard state along tensor parallel group.
                partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
                if partition_dim is not None:
                    slice_size = current_shape[partition_dim]
                    v = v.split(slice_size, dim=partition_dim)[self.tp_rank]

                # Shard state along data parallel group when using Zero.
                if self.use_zero and not is_moe_param:
                    padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
                    with torch.no_grad():
                        v = v.flatten()
                        if padding_size > 0:
                            v = torch.nn.functional.pad(v, [0, padding_size])
                        slice_size = v.numel() // self.dp_size
                        v = v.split(slice_size, dim=0)[self.dp_rank]

                state_[k] = v.detach().clone().to(device)

        return state_

    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        raise NotImplementedError

    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
        raise NotImplementedError

    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
        raise NotImplementedError
applications/ColossalMoE/colossal_moe/models/mixtral_layer.py (new file, mode 100644)

import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info


class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
    def __init__(self, config):
        super().__init__(config)
        self.setup_ep()

    def setup_ep(self):
        _, moe_info = MOE_MANAGER.get_info(self.num_experts)
        ep_group = moe_info.ep_group
        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
        assert self.num_experts % self.ep_size == 0
        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
        set_tensors_to_none(self.experts, exclude=set(held_experts))
        for p in self.experts.parameters():
            set_moe_tensor_info(p, moe_info)

    @staticmethod
    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
        LazyInitContext.materialize(module)
        module.__class__ = EPMixtralSparseMoeBlock
        module.setup_ep()
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        selected_experts = selected_experts.t().reshape(-1)
        selected_experts_idx = selected_experts.argsort()
        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
        output_split_sizes = torch.zeros_like(input_split_sizes)
        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
        # compute expert output
        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
        if output_states.size(0) > 0:
            if self.num_experts_per_ep == 1:
                # no need to split
                expert = self.experts[self.expert_start_idx]
                output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                output_states = expert.w2(output_states)
            else:
                output_states_splits = output_states.split(output_split_sizes.tolist())
                output_states_list = []
                for i, split_states in enumerate(output_states_splits):
                    if split_states.size(0) == 0:
                        continue
                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                    split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                    split_states = expert.w2(split_states)
                    output_states_list.append(split_states)
                output_states = torch.cat(output_states_list)
        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
        recover_experts_idx = torch.empty_like(selected_experts_idx)
        recover_experts_idx[selected_experts_idx] = torch.arange(
            selected_experts_idx.size(0), device=selected_experts_idx.device
        )
        dispatch_states = dispatch_states[recover_experts_idx]
        k_hidden_states = dispatch_states.chunk(self.top_k)
        output_states = k_hidden_states[0] * routing_weights[:, 0, None]
        for i in range(1, self.top_k):
            output_states += k_hidden_states[i] * routing_weights[:, i, None]
        output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
        return output_states, router_logits
applications/ColossalMoE/colossal_moe/models/mixtral_policy.py (new file, mode 100644)

from functools import partial
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer,
    MixtralForCausalLM,
    MixtralModel,
    MoeCausalLMOutputWithPast,
    _prepare_4d_causal_attention_mask,
    load_balancing_loss_func,
)
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.shard import ShardConfig

from .mixtral_layer import EPMixtralSparseMoeBlock

__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]


class MixtralPolicy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self):
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            raise NotImplementedError(
                "Mixtral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
            )

        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")

        # expert parallel
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="block_sparse_moe",
                    target_module=EPMixtralSparseMoeBlock,
                )
            ],
            policy=policy,
            target_key=MixtralDecoderLayer,
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key=MixtralDecoderLayer,
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key=MixtralModel,
            )

        if self.shard_config.enable_flash_attention:
            raise NotImplementedError("Flash attention has already been replaced in mixtral.")

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replacing the original forward method of huggingface
        to customized forward method, and add this changing to policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "MixtralModel":
                module = self.model
            else:
                module = self.model.model

            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
            method_replacement = {
                "forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)
            }
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

        return

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "MixtralModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers


class MixtralModelPolicy(MixtralPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=MixtralModel,
                new_forward=MixtralPipelineForwards.mixtral_model_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in Mixtral model"""
        return []


class MixtralForCausalLMPolicy(MixtralPolicy):
    def module_policy(self):
        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                MixtralForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True),
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=MixtralForCausalLM,
                new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        llama_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if (
                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
                and self.pipeline_stage_manager.num_stages > 1
            ):
                # tie weights
                return [
                    {
                        0: llama_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []


class MixtralPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Mixtral models
    under pipeline setting.
    """

    @staticmethod
    def mixtral_model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        past_router_logits: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if stage_manager.is_first_stage():
            # retrieve input_ids and inputs_embeds
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
            elif input_ids is not None:
                batch_size, seq_length = input_ids.shape
            elif inputs_embeds is not None:
                batch_size, seq_length, _ = inputs_embeds.shape
            else:
                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            hidden_states = inputs_embeds
        else:
            input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
            device = hidden_states.device

        seq_length_with_past = seq_length
        past_key_values_length = 0

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False
        if use_cache:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions, for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        start_idx, end_idx = stage_index[0], stage_index[1]
        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    output_attentions,
                    output_router_logits,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None

        if output_router_logits and past_router_logits is not None:
            all_router_logits = past_router_logits + all_router_logits
        if stage_manager.is_last_stage():
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        # always return dict for intermediate stage
        return {
            "hidden_states": hidden_states,
            "past_router_logits": all_router_logits,
        }

    @staticmethod
    def mixtral_for_causal_lm_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        past_router_logits: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = MixtralPipelineForwards.mixtral_model_forward(
            self.model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            past_router_logits=past_router_logits,
        )
        past_key_values = None

        if stage_manager.is_last_stage():
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            logits = logits.float()
            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

            aux_loss = None
            if output_router_logits:
                aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                if labels is not None:
                    loss += self.router_aux_loss_coef * aux_loss

            if not return_dict:
                output = (logits,) + outputs[1:]
                if output_router_logits:
                    output = (aux_loss,) + output
                return (loss,) + output if loss is not None else output

            return MoeCausalLMOutputWithPast(
                loss=loss,
                aux_loss=aux_loss,
                logits=logits,
                past_key_values=None,
                hidden_states=outputs[0],
                attentions=None,
                router_logits=outputs[-1],
            )
        else:
            out = {}
            hidden_states = outputs.get("hidden_states")
            out["hidden_states"] = hidden_states
            if output_router_logits:
                out["past_router_logits"] = outputs["past_router_logits"]
            return out
applications/ColossalMoE/colossal_moe/utils.py (new file, mode 100644)

import json
import os
from typing import Any, Dict, Tuple, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator


def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Load file in JSON format
    """
    with open(file=file_path, mode="r", encoding="utf-8") as fp:
        return json.load(fp)


def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Save as JSON format
    """
    with open(file=file_path, mode="w", encoding="utf-8") as fp:
        json.dump(data, fp=fp, ensure_ascii=False, indent=4)


def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:
    """
    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
    os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))


def load_checkpoint(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
    """
    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    # Update booster params states.
    booster.load_model(model, os.path.join(load_dir, "modeling"))
    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

    running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
    return (
        running_states["epoch"],
        running_states["step"],
        running_states["sample_start_index"],
    )
applications/ColossalMoE/infer.py (new file, mode 100644)

import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="ep",
        choices=["ep"],
        help="Parallel methods.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    config = MixtralConfig.from_pretrained(args.model_name)
    ep_size = min(dist.get_world_size(), config.num_local_experts)

    # Set plugin
    if args.plugin == "ep":
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=1,
            ep_size=ep_size,
            zero_stage=1,
            precision=args.precision,
            custom_policy=MixtralForCausalLMPolicy(),
            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master(f"Finish load model")

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Set booster
    booster = Booster(plugin=plugin)
    model, _, _, _, _ = booster.boost(model=model)
    coordinator.print_on_master(f"Finish init booster")

    model.eval()

    if coordinator.rank == 0:
        text = ["Hello my name is"]
    else:
        text = [
            "What's the largest country in the world?",
            "How many people live in China?",
            "帮我续写这首诗:离离原上草",
        ]
    tokenizer.pad_token = tokenizer.unk_token
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())

    with torch.no_grad():
        outputs = model.module.generate(**inputs, max_new_tokens=20)
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    print(f"[{coordinator.rank}] {outputs}")


if __name__ == "__main__":
    main()
applications/ColossalMoE/infer.sh (new file, mode 100644)
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"

# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
    --model_name $MODEL \
    --plugin "ep"
applications/ColossalMoE/requirements.txt (new file, mode 100644)
colossalai >= 0.3.3
torch >= 1.8.1
transformers == 4.36.0
sentencepiece
datasets
applications/ColossalMoE/setup.py (new file, mode 100644)
from setuptools import find_packages, setup


def fetch_requirements(path):
    with open(path, "r") as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme():
    with open("README.md", encoding="utf-8") as f:
        return f.read()


def fetch_version():
    with open("version.txt", "r") as f:
        return f.read().strip()


setup(
    name="colossal_moe",
    version=fetch_version(),
    packages=find_packages(
        exclude=(
            "tests",
            "benchmarks",
            "*.egg-info",
        )
    ),
    description="Colossal-AI MoE",
    long_description=fetch_readme(),
    long_description_content_type="text/markdown",
    license="Apache Software License 2.0",
    url="https://github.com/hpcaitech",
    install_requires=fetch_requirements("requirements.txt"),
    python_requires=">=3.6",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Environment :: GPU :: NVIDIA CUDA",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: System :: Distributed Computing",
    ],
)
applications/ColossalMoE/tests/__init__.py (new empty file, mode 100644)
applications/ColossalMoE/tests/test_mixtral_layer.py (new file, mode 100644)
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_mixtral_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    MOE_MANAGER.setup(
        parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
    )
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
    )
    torch.manual_seed(0)
    orig_model = MixtralSparseMoeBlock(config).cuda()
    x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
    orig_output, orig_logits = orig_model(x)
    model = deepcopy(orig_model)
    model = EPMixtralSparseMoeBlock.from_native_module(model)
    ep_output, ep_logits = model(x)
    assert_close(orig_logits, ep_logits)
    assert_close(orig_output, ep_output)
    orig_loss = orig_output.mean()
    orig_loss.backward()
    ep_loss = ep_output.mean()
    ep_loss.backward()
    assert_close(orig_loss, ep_loss)
    name_to_p = {n: p for n, p in orig_model.named_parameters()}
    for n, ep_p in model.named_parameters():
        p = name_to_p[n]
        if ep_p.grad is not None:
            assert_close(p.grad, ep_p.grad)


def run_dist(rank: int, world_size: int, port: int):
    colossalai.launch({}, rank, world_size, "localhost", port)
    check_mixtral_moe_layer()


@pytest.mark.parametrize("world_size", [2, 4])
def test_mixtral_moe_layer(world_size: int):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral_moe_layer(2)
applications/ColossalMoE/tests/test_moe_checkpoint.py (new file, mode 100644)
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_model_equal(model1, model2):
    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        assert torch.equal(p1.half(), p2.half())


def get_optimizer_snapshot(optim):
    state = {id(k): deepcopy(v) for k, v in optim.state.items()}
    param_groups = []
    for group in optim.param_groups:
        params = [id(p) for p in group["params"]]
        new_group = {"params": params}
        for k, v in group.items():
            if k != "params":
                new_group[k] = v
        param_groups.append(new_group)
    return {
        "state": state,
        "param_groups": param_groups,
    }


def check_optimizer_snapshot_equal(snapshot1, snapshot2):
    # check param_groups
    assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
    for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
        assert set(group1.keys()) == set(group2.keys())
        for k in group1.keys():
            assert group1[k] == group2[k]
    # check state
    assert set(snapshot1["state"].keys()) == set(
        snapshot2["state"].keys()
    ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
    for pid in snapshot1["state"].keys():
        state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
        assert set(state1.keys()) == set(state2.keys())
        for k in state1.keys():
            if isinstance(state1[k], torch.Tensor):
                assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
            else:
                assert state1[k] == state2[k]


def check_mixtral_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    torch.manual_seed(0)
    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
    orig_model = MixtralForCausalLM(config).cuda()
    model = deepcopy(orig_model)
    optimizer = Adam(model.parameters(), lr=1e-3)
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        ep_size=2,
        custom_policy=MixtralForCausalLMPolicy(),
        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        microbatch_size=1,
        zero_stage=1,
    )
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
    # initialize grads
    data_iter = iter(
        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
    )
    booster.execute_pipeline(
        data_iter,
        model,
        lambda outputs, inputs: outputs.loss,
        optimizer,
    )

    # check save model
    booster.save_model(model, "mixtral_model", shard=True)
    dist.barrier()
    if dist.get_rank() == 0:
        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
        check_model_equal(orig_model, saved_model)
        saved_model.save_pretrained("mixtral_hf_model")
    dist.barrier()

    # check load model
    new_model = MixtralForCausalLM(config).cuda()
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
    booster.load_model(new_model, "mixtral_hf_model")
    check_model_equal(model, new_model)

    # check save optimizer
    optimizer.step()
    for group in optimizer.param_groups:
        group["lr"] = 0.1
    snapshot = get_optimizer_snapshot(optimizer.unwrap())
    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
    dist.barrier()
    # reset optimizer state
    for state in optimizer.unwrap().state.values():
        for v in state.values():
            if isinstance(v, torch.Tensor):
                v.zero_()
    booster.load_optimizer(optimizer, "mixtral_optim")
    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)


def run_dist(rank: int, world_size: int, port: int):
    colossalai.launch({}, rank, world_size, "localhost", port)
    check_mixtral_moe_layer()


@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral_moe_layer(4)
applications/ColossalMoE/train.py (new file, mode 100644)
import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


@torch.no_grad()
def get_global_loss(loss, booster):
    global_loss = loss.clone().detach()
    dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
    global_loss.div_(booster.plugin.dp_size)
    return global_loss


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["hybrid"],
        help="Parallel methods.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help="The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

    # lr scheduler
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")

    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )

    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
    # communication overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    if args.plugin == "hybrid":
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=args.pp_size,
            ep_size=args.ep_size,
            microbatch_size=args.microbatch_size,
            custom_policy=MixtralForCausalLMPolicy(),
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
            precision=args.precision,
            zero_stage=args.zero_stage,
            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master("Finish init model")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
    collate_fn = None
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )

    # Set optimizer
    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # Set lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * len(dataloader),
        warmup_steps=args.warmup_steps
        if args.warmup_steps is not None
        else int(args.num_epochs * len(dataloader) * 0.025),
        eta_min=0.1 * args.lr,
    )

    # Set booster
    booster = Booster(plugin=plugin)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master("Finish init booster")

    # Load ckpt
    if args.load_checkpoint is not None:
        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
        coordinator.print_on_master("Finish loading checkpoint")

    # Start finetuning
    coordinator.print_on_master("Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not coordinator.is_master() if not use_pipeline else not is_pp_last_stage,
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    # Backward and optimize
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        global_loss = get_global_loss(loss, booster)
                        if coordinator._local_rank == "0":
                            pbar.set_postfix({"Loss": global_loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Apply load balance
                # if (
                #     args.load_balance
                #     and args.load_balance_interval > 0
                #     and (step + 1) % args.load_balance_interval == 0
                # ):
                #     coordinator.print_on_master(f"Apply load balance")
                #     apply_load_balance(model, optimizer)
                # save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    save_checkpoint(
                        args.output_path,
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step,
                        args.batch_size,
                        coordinator,
                    )

        # save checkpoint at the end of each epoch
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training
    coordinator.print_on_master("Finish training")


if __name__ == "__main__":
    main()
applications/ColossalMoE/train.sh (new file, mode 100644)
NUM_GPU=8
MODEL="mistralai/Mixtral-8x7B-v0.1"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001

# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU \
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
    train.py \
    --num_epoch 1 \
    --model_name $MODEL \
    --plugin "hybrid" \
    --batch_size $BATCH_SIZE \
    --lr $LR \
    --zero_stage 1 \
    --pp_size 2 \
    --dp_size 1 \
    --ep_size 8
applications/ColossalMoE/version.txt (new file, mode 100644)
1.0.0
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (modified)
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self,
         tp_size: int,
         pp_size: int,
+        ep_size: int,
         extra_dp_size: int = 1,
         precision: str = "fp16",
         zero_stage: int = 0,
@@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_communication: bool = True,
         use_ep_inside: bool = True,
         custom_policy: Policy = None,
+        checkpoint_io: Optional[MoECheckpintIO] = None,
     ) -> None:
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
@@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         if enable_sequence_parallelism:
             assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
-        assert (
-            dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        assert (
+            dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+        self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
+        MOE_MANAGER.setup(
+            parallel="EP",
+            mode="fixed",
+            fixed_dp_size=self.real_dp_size,
+            fixed_ep_size=ep_size,
+            fixed_pp_size=pp_size,
+            use_ep_inside=use_ep_inside,
+        )
         self.tp_size = tp_size
         self.pp_size = pp_size
         self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+        self.ep_size = ep_size
+        self.moe_info = MOE_MANAGER.get_info(0)[1]
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
@@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.checkpoint_io = checkpoint_io
         # we change pg mesh to (pp, dp, tp) for better moe performance
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
@@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )

     def get_checkpoint_io(self) -> MoECheckpintIO:
-        self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        if self.checkpoint_io is None:
+            self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        else:
+            self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io

     def configure(
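The new `checkpoint_io` argument accepts a CheckpointIO class rather than an instance; `get_checkpoint_io()` now instantiates it lazily with the plugin's process groups and falls back to `MoECheckpintIO` when nothing is passed. A minimal sketch of the intended call site, mirroring infer.py and train.py above (the tp/pp/ep/zero values here are illustrative only):

# Sketch: hand the plugin the CheckpointIO class; it is constructed later as
# checkpoint_io(dp_group, pp_group, tp_group, zero_stage) inside get_checkpoint_io().
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
)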
colossalai/checkpoint_io/checkpoint_io_base.py (modified)
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from colossalai.interface import ModelWrapper

-from .utils import has_index_file
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

 __all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
         if index_file_exists:
             self.load_sharded_model(model, index_file_path, strict)
         else:
-            self.load_unsharded_model(model, checkpoint, strict)
+            path = Path(checkpoint, SAFE_WEIGHTS_NAME)
+            if path.is_file():
+                self.load_unsharded_model(model, str(path), strict)
+            else:
+                path = Path(checkpoint, WEIGHTS_NAME)
+                if path.is_file():
+                    self.load_unsharded_model(model, str(path), strict)
+                else:
+                    self.load_unsharded_model(model, checkpoint, strict)

         return origin_model
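The practical effect of the `load_model` change is that a directory holding a single unsharded weight file but no index (for example one written by HuggingFace's `save_pretrained()`) can now be passed to `booster.load_model()` directly, which is exactly what test_moe_checkpoint.py above does with the "mixtral_hf_model" directory. A small sketch of the resolution order; the concrete filenames are an assumption about the `SAFE_WEIGHTS_NAME` / `WEIGHTS_NAME` constants in `checkpoint_io.utils`, which this diff does not show:

# Sketch of the new fallback when the checkpoint path is a directory without an index file
# (assumed filenames: SAFE_WEIGHTS_NAME = "model.safetensors", WEIGHTS_NAME = "pytorch_model.bin"):
#   1. <checkpoint>/model.safetensors, if it exists
#   2. <checkpoint>/pytorch_model.bin, if it exists
#   3. the checkpoint path itself, unchanged (previous behaviour)
booster.load_model(model, "mixtral_hf_model")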
colossalai/moe/__init__.py (modified)
 from .checkpoint import MoECheckpintIO
 from .experts import MLPExperts
-from .layers import SparseMLP
+from .layers import SparseMLP, apply_load_balance
+from .manager import MOE_MANAGER
 from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator
@@ -14,4 +15,6 @@ __all__ = [
     "UniformNoiseGenerator",
     "SparseMLP",
     "MoECheckpintIO",
+    "MOE_MANAGER",
+    "apply_load_balance",
 ]
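With these re-exports in place, callers can pull the MoE manager and the load-balancing helper straight from the package root, which is what the (currently commented-out) load-balance hook in train.py above would need; a one-line sketch:

from colossalai.moe import MOE_MANAGER, apply_load_balance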