Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
71fe5276
Commit
71fe5276
authored
Jun 09, 2023
by
Frank Lee
Committed by
ver217
Jun 12, 2023
Browse files
[gemini] fixed the gemini checkpoint io (#3934)
parent
b3ab7fba
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
11 deletions
+19
-11
colossalai/booster/plugin/gemini_plugin.py
colossalai/booster/plugin/gemini_plugin.py
+5
-2
colossalai/checkpoint_io/index_file.py
colossalai/checkpoint_io/index_file.py
+10
-8
colossalai/zero/gemini/gemini_ddp.py
colossalai/zero/gemini/gemini_ddp.py
+4
-1
No files found.
colossalai/booster/plugin/gemini_plugin.py
View file @
71fe5276
...
...
@@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
save_state_dict
(
shard
,
checkpoint_file_path
,
use_safetensors
)
index_file
.
append_meta_data
(
"total_size"
,
total_size
)
index_file
.
write_index_file
(
save_index_file
)
logging
.
info
(
f
"The model is going to be split to checkpoint shards. "
# only save the index file on the master rank
if
self
.
coordinator
.
is_master
():
index_file
.
write_index_file
(
save_index_file
)
logging
.
info
(
f
"The model is split into checkpoint shards. "
f
"You can find where each parameters has been saved in the "
f
"index located at
{
save_index_file
}
."
)
...
...
colossalai/checkpoint_io/index_file.py
View file @
71fe5276
import
json
from
pathlib
import
Path
from
typing
import
Any
,
List
,
Union
import
os
import
json
from
collections
import
OrderedDict
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
,
Union
from
.utils
import
is_dtensor_checkpoint
...
...
@@ -22,8 +22,10 @@ class CheckpointIndexFile:
def __init__(self, root_path=None) -> None:
    """Initialize an empty checkpoint index.

    Args:
        root_path: Directory that the index file (and its shards) live in.
            May be ``None`` when the index is built purely in memory.
    """
    self.root_path = root_path
    # An OrderedDict is used so that entries are written out in exactly
    # the order they were appended, keeping the serialized index file's
    # tensor order stable across runs.
    self.metadata: Dict = OrderedDict()
    self.weight_map: Dict = OrderedDict()
@
staticmethod
def
from_file
(
index_path
:
Union
[
str
,
Path
]):
...
...
@@ -150,13 +152,13 @@ class CheckpointIndexFile:
"""
ckpt_path
=
self
.
weight_map
[
param_name
]
return
ckpt_path
def get_all_param_names(self):
    """Return every parameter name recorded in the weight map, in order."""
    return [param_name for param_name in self.weight_map]
def
write_index_file
(
self
,
save_index_file
):
"""
Write index file.
...
...
@@ -164,5 +166,5 @@ class CheckpointIndexFile:
save_index_file
=
os
.
path
.
join
(
self
.
root_path
,
save_index_file
)
index
=
{
"metadata"
:
self
.
metadata
,
"weight_map"
:
self
.
weight_map
}
with
open
(
save_index_file
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
content
=
json
.
dumps
(
index
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
content
=
json
.
dumps
(
index
,
indent
=
2
)
+
"
\n
"
f
.
write
(
content
)
colossalai/zero/gemini/gemini_ddp.py
View file @
71fe5276
...
...
@@ -716,7 +716,10 @@ class _StateDictSharder:
tensor_size
=
calculate_tensor_size
(
tensor
)
ret_block
=
None
ret_block_size
=
0
if
self
.
current_block_size
+
tensor_size
>
self
.
max_shard_size
:
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if
self
.
current_block_size
+
tensor_size
>
self
.
max_shard_size
and
self
.
current_block_size
>
0
:
ret_block
=
self
.
current_block
ret_block_size
=
self
.
current_block_size
self
.
current_block
=
OrderedDict
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment