OpenDAS / ColossalAI · Commits

Commit c040d70a (unverified)
Authored Oct 31, 2023 by Baizhou Zhang; committed by GitHub on Oct 31, 2023
[hotfix] fix the bug of repeatedly storing param group (#4951)
parent be82b5d4

Showing 2 changed files, with 10 additions and 9 deletions:

colossalai/booster/plugin/gemini_plugin.py (+6, -6)
colossalai/booster/plugin/low_level_zero_plugin.py (+4, -3)
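As a reading of the diff (the commit message itself says only that it fixes repeated storing of param groups): before this change, every rank in a distributed run executed the param-group save path, so the same param-group file was written once per process. Both hunks below guard the write with self.coordinator.is_master(), and hoist index_file.append_meta_data("param_groups", param_group_file) above the guard, presumably so non-master ranks still record the filename in their index metadata. A minimal sketch of the resulting pattern in Python — save_param_groups_on_master is a hypothetical name for illustration, while get_param_groups_for_saving and the coordinator check come from the diff:

import os

import torch


def save_param_groups_on_master(optimizer, checkpoint: str, param_group_file: str, is_master: bool) -> None:
    """Persist optimizer param groups from the master rank only."""
    if is_master:  # e.g. self.coordinator.is_master() in the plugins below
        # Only one process writes, so the file lands in the checkpoint
        # directory exactly once instead of once per rank.
        group_file_path = os.path.join(checkpoint, param_group_file)
        param_groups = optimizer.get_param_groups_for_saving()
        torch.save(param_groups, group_file_path)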
colossalai/booster/plugin/gemini_plugin.py

@@ -150,24 +150,24 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
         index_file = CheckpointIndexFile(checkpoint)
+        index_file.append_meta_data("param_groups", param_group_file)

         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        param_groups = optimizer.get_param_groups_for_saving()
-        torch.save(param_groups, group_file_path)
+        if self.coordinator.is_master():
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            param_groups = optimizer.get_param_groups_for_saving()
+            torch.save(param_groups, group_file_path)

         # States are broken into shards within max_shard_size.
         state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)

         # Save shards of optimizer states.
-        is_master = self.coordinator.is_master()
         total_size = save_state_dict_shards(
             sharded_state_dict=state_dict_shard,
             checkpoint=checkpoint,
             index_file=index_file,
             base_filename=states_name,
-            is_master=is_master,
+            is_master=self.coordinator.is_master(),
             use_safetensors=False,
         )
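A follow-up note on the gemini hunk: besides the master guard, it also drops the one-use local is_master and passes self.coordinator.is_master() directly to save_state_dict_shards; the call's behavior is unchanged.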
colossalai/booster/plugin/low_level_zero_plugin.py

@@ -119,11 +119,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
         index_file = CheckpointIndexFile(checkpoint)
+        index_file.append_meta_data("param_groups", param_group_file)

         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(state_dict, group_file_path)
+        if self.coordinator.is_master():
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            save_param_groups(state_dict, group_file_path)

         # Save shards of optimizer states.
         total_size = 0
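With both call sites guarded the same way, each checkpoint directory receives exactly one param-group file per save; the ZeRO hunk differs from the gemini one only in persisting through the existing save_param_groups(state_dict, group_file_path) helper rather than torch.save.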