Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b60be18d
Commit
b60be18d
authored
Jan 27, 2024
by
Hongxin Liu
Committed by
ver217
Feb 07, 2024
Browse files
[moe] fix mixtral checkpoint io (#5314)
parent
da39d21b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
...ons/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
+8
-4
No files found.
applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
View file @
b60be18d
...
@@ -135,6 +135,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
...
@@ -135,6 +135,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Path
(
checkpoint
).
mkdir
(
parents
=
True
,
exist_ok
=
True
)
Path
(
checkpoint
).
mkdir
(
parents
=
True
,
exist_ok
=
True
)
if
self
.
real_dp_rank
!=
0
:
if
self
.
real_dp_rank
!=
0
:
dist
.
barrier
()
return
return
# ep_rank 0 saves all the parameters and buffers.
# ep_rank 0 saves all the parameters and buffers.
...
@@ -171,6 +172,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
...
@@ -171,6 +172,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
f
"index located at
{
save_index_file
}
."
f
"index located at
{
save_index_file
}
."
)
)
dist
.
barrier
()
else
:
else
:
# When pipeline is used, each stage produces its own shard files and index files.
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
...
@@ -201,10 +203,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
...
@@ -201,10 +203,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
index_file
.
append_meta_data
(
"total_size"
,
total_size
)
index_file
.
append_meta_data
(
"total_size"
,
total_size
)
index_file
.
write_index_file
(
save_index_file
)
index_file
.
write_index_file
(
save_index_file
)
else
:
else
:
dist
.
barrier
()
return
return
dist
.
barrier
(
self
.
pp_group
)
dist
.
barrier
()
dist
.
barrier
(
self
.
ep_group
)
# The global master rank integrates the index files and clean the folder.
# The global master rank integrates the index files and clean the folder.
if
self
.
coordinator
.
is_master
():
if
self
.
coordinator
.
is_master
():
...
@@ -360,6 +362,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
...
@@ -360,6 +362,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
# Devices along the same dp_group share the same copies of states when zero is not used.
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the model.
# In this case only let the device with dp_rank == 0 save the model.
if
not
self
.
use_zero
and
self
.
real_dp_rank
!=
0
:
if
not
self
.
use_zero
and
self
.
real_dp_rank
!=
0
:
dist
.
barrier
()
return
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Then collect the sharded states along dp_group(if using zero)/tp_group.
...
@@ -401,6 +404,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
...
@@ -401,6 +404,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
f
"index located at
{
save_index_file
}
."
f
"index located at
{
save_index_file
}
."
)
)
dist
.
barrier
()
else
:
else
:
# When pipeline is used, each stage produces its own shard files and index files.
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
...
@@ -428,10 +432,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
...
@@ -428,10 +432,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
index_file
.
append_meta_data
(
"total_size"
,
total_size
)
index_file
.
append_meta_data
(
"total_size"
,
total_size
)
index_file
.
write_index_file
(
save_index_file
)
index_file
.
write_index_file
(
save_index_file
)
else
:
else
:
dist
.
barrier
()
return
return
dist
.
barrier
(
self
.
pp_group
)
dist
.
barrier
()
dist
.
barrier
(
self
.
ep_group
)
# The global master rank integrates the index files and clean the folder.
# The global master rank integrates the index files and clean the folder.
if
self
.
coordinator
.
is_master
():
if
self
.
coordinator
.
is_master
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment