OpenDAS / ColossalAI · Commits
Commit 65e5d6ba
Authored Feb 01, 2024 by Hongxin Liu, committed by ver217 on Feb 07, 2024

[moe] fix mixtral optim checkpoint (#5344)

Parent: 956b561b
Showing 2 changed files with 12 additions and 2 deletions (+12 -2)
applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py    +10 -2
applications/ColossalMoE/tests/test_moe_checkpoint.py    +2 -0
applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
@@ -393,7 +393,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
         # Store param groups.
         index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
         # Store index file.
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
@@ -451,7 +455,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
         # Store param groups.
         final_index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
         final_index_file.write_index_file(final_index_file_path)
         rmtree(tmp_index_file_folder)
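Note on the change above (editor's sketch, not part of the commit): instead of passing `optimizer.param_info` straight to `save_param_groups`, the fix rebuilds `param_groups` by merging each live group in `optimizer.param_groups` (which carries the current hyperparameters such as `lr`) with the corresponding `params` entry recorded in `optimizer.param_info["param_groups"]`, and saves that merged structure, presumably so the checkpointed groups reflect the optimizer's actual settings. A minimal, runnable illustration of the merge with hypothetical values:

    # Illustrative only; names mirror the diff, values are made up.
    optimizer_param_groups = [{"lr": 0.1, "betas": (0.9, 0.999), "params": ["w", "b"]}]
    param_info = {"param_groups": [{"params": [0, 1]}]}  # per-group "params" bookkeeping kept by the plugin

    # Same merge as the added lines: keep the live hyperparameters, swap in the recorded "params".
    param_groups = [
        {**group, "params": group_info["params"]}
        for group, group_info in zip(optimizer_param_groups, param_info["param_groups"])
    ]
    print(param_groups)  # [{'lr': 0.1, 'betas': (0.9, 0.999), 'params': [0, 1]}]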
applications/ColossalMoE/tests/test_moe_checkpoint.py
@@ -117,6 +117,8 @@ def check_mixtral_moe_layer():
     # check save optimizer
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
     snapshot = get_optimizer_snapshot(optimizer.unwrap())
     booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
     dist.barrier()
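Note on the test change (editor's reading, not part of the commit): setting `group["lr"] = 0.1` after `optimizer.step()` and before the snapshot and `booster.save_optimizer` call makes the saved checkpoint carry a non-default learning rate, so the test's later restore-and-compare presumably only passes if the checkpoint preserves the live `param_groups` hyperparameters, which is what the `mixtral_checkpoint.py` change above fixes.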