OpenDAS / ColossalAI · commit c9cff7e7 (unverified)

[checkpointio] General Checkpointing of Sharded Optimizers (#3984)

Authored Jun 15, 2023 by Baizhou Zhang; committed via GitHub on Jun 15, 2023.
Parent: 8bcad736
Showing 8 changed files with 399 additions and 38 deletions (+399 / -38):

  colossalai/booster/plugin/gemini_plugin.py                +3   -3
  colossalai/booster/plugin/torch_ddp_plugin.py             +13  -3
  colossalai/booster/plugin/torch_fsdp_plugin.py            +5   -4
  colossalai/checkpoint_io/checkpoint_io_base.py            +6   -6
  colossalai/checkpoint_io/general_checkpoint_io.py         +83  -12
  colossalai/checkpoint_io/index_file.py                    +13  -1
  colossalai/checkpoint_io/utils.py                         +177 -8
  tests/test_checkpoint_io/test_general_checkpoint_io.py    +99  -1
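In short, this commit renames the `variant` keyword to `prefix` throughout the model-checkpoint path and gives `GeneralCheckpointIO` a working sharded save/load path for optimizers. A hedged round-trip sketch of the new API, modeled on the test added at the bottom of this diff (paths and hyperparameters are illustrative):

    import tempfile

    import torch
    from torch.optim import Adam
    from torchvision.models import resnet18

    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = resnet18()
    optimizer = Adam(model.parameters(), lr=0.001)
    model(torch.randn(1, 3, 224, 224)).sum().backward()
    optimizer.step()    # populate optimizer states before saving

    ckpt_dir = tempfile.TemporaryDirectory()
    ckpt_io = GeneralCheckpointIO()
    # shard=True routes to the new save_sharded_optimizer; size_per_shard is in MB.
    ckpt_io.save_optimizer(optimizer, ckpt_dir.name, shard=True, size_per_shard=10)

    # Loading detects the index file and dispatches to load_sharded_optimizer.
    new_model = resnet18()
    new_optimizer = Adam(new_model.parameters(), lr=0.001)
    ckpt_io.load_optimizer(new_optimizer, ckpt_dir.name)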
colossalai/booster/plugin/gemini_plugin.py

@@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
+from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                            model: GeminiDDP,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
         Save sharded model
         """
         state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
-        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
         for idx, shard_pair in enumerate(state_dict_shard):
colossalai/booster/plugin/torch_ddp_plugin.py

@@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
             super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
@@ -54,11 +53,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
                            model: nn.Module,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
         Save model to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
+            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
 
 class TorchDDPModel(ModelWrapper):
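Note that every save in this plugin is guarded by `self.coordinator.is_master()`, since DDP replicates optimizer state on every rank. A generic sketch of that guard outside ColossalAI (illustrative only, not the plugin code):

    import torch
    import torch.distributed as dist

    def save_on_master(obj, path: str) -> None:
        # DDP keeps identical optimizer state on every rank, so writing
        # from a single process avoids redundant (and racy) file I/O.
        if not dist.is_initialized() or dist.get_rank() == 0:
            torch.save(obj, path)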
colossalai/booster/plugin/torch_fsdp_plugin.py

+import warnings
 from pathlib import Path
 from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import warnings
 from packaging import version
 from torch.distributed import ProcessGroup
@@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
         utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
 
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
                            size_per_shard: int, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
@@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         raise NotImplementedError("Sharded model checkpoint is not supported yet.")
 
-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
         """
         Save optimizer to checkpoint but only on master process.
         """
         raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
         """
         Load optimizer from checkpoint but only on master process.
         """
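For FSDP the sharded optimizer paths still raise NotImplementedError; only the unsharded path works, by consolidating the flattened shards on rank 0 via `FSDP.full_optim_state_dict`, as the context lines above show. A hedged sketch of that consolidation using the PyTorch FSDP API of this era:

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def save_full_optim_state(fsdp_model, optimizer, checkpoint: str) -> None:
        # Consolidate flattened, sharded optimizer states into a single
        # rank-0 state dict keyed like the unwrapped model's parameters.
        full_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
        if full_state:    # empty on non-zero ranks when rank0_only=True
            torch.save(full_state, checkpoint)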
colossalai/checkpoint_io/checkpoint_io_base.py

@@ -103,7 +103,7 @@ class CheckpointIO(ABC):
                    checkpoint: str,
                    shard: bool = False,
                    gather_dtensor: bool = True,
-                   variant: str = None,
+                   prefix: str = None,
                    size_per_shard: int = 1024,
                    use_safetensors: bool = False):
         """
@@ -128,7 +128,7 @@ class CheckpointIO(ABC):
                 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
             gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
-            variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
+            prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
@@ -137,11 +137,11 @@ class CheckpointIO(ABC):
             model = model.unwrap()
 
         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
+            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:
             self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """
         Load optimizer from checkpoint.
@@ -157,7 +157,7 @@ class CheckpointIO(ABC):
         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path)
+            self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
         else:
             self.load_unsharded_optimizer(optimizer, checkpoint)
@@ -218,7 +218,7 @@ class CheckpointIO(ABC):
         pass
 
     @abstractmethod
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
                            size_per_shard: int, use_safetensors: bool):
         """
         Save model to sharded checkpoint.
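The reworked `load_optimizer` decides between the sharded and unsharded paths purely by probing for an index file. A condensed sketch of the dispatch, shown standalone for brevity and assuming `has_index_file` returns an existence flag plus the resolved index path, as the surrounding code suggests:

    from torch.optim import Optimizer

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str,
                       prefix: str = None, size_per_shard: int = 1024):
        # has_index_file is assumed to return (exists, resolved_index_path).
        index_file_exists, index_file_path = has_index_file(checkpoint)
        if index_file_exists:
            # the existence of an index file marks a sharded checkpoint
            self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
        else:
            self.load_unsharded_optimizer(optimizer, checkpoint)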
colossalai/checkpoint_io/general_checkpoint_io.py

@@ -11,15 +11,21 @@ from torch.optim import Optimizer
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    get_base_filenames,
+    get_model_base_filenames,
+    get_optimizer_base_filenames,
     get_shard_filename,
     has_index_file,
     is_safetensors_available,
+    load_param_groups_into_optimizer,
     load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_states_into_optimizer,
+    save_param_groups,
     save_state_dict,
-    shard_checkpoint,
+    shard_model_checkpoint,
+    shard_optimizer_checkpoint,
+    sharded_optimizer_loading_epilogue,
 )
 
 __all__ = ['GeneralCheckpointIO']
@@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
-
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        checkpoint = load_state_dict(checkpoint)
-        optimizer.load_state_dict(checkpoint)
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+        """
+        Load sharded optimizer with the given path to index file.
+        """
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups.
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+            del state_dict
+            gc.collect()
+
+        sharded_optimizer_loading_epilogue(optimizer)
 
     def save_sharded_optimizer(
         self,
@@ -59,7 +83,54 @@ class GeneralCheckpointIO(CheckpointIO):
         prefix: str,
         size_per_shard: int,
     ):
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+        """
+        Save sharded optimizer checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
+        """
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Offload optimizer states. States are broken into shards within max_shard_size.
+        state_dict = optimizer.state_dict()
+        sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
+
+        # Prepare file paths and the index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        save_param_groups(state_dict, group_file_path)
+
+        # Save shards of optimizer states.
+        total_size = 0
+        for idx, shard_pair in enumerate(sharded_state):
+            shard, current_size = shard_pair
+            shard_file = get_shard_filename(states_name, idx)
+            total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+        # Wrap up the index file.
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
+        logging.info(f"The optimizer is going to be split to checkpoint shards. "
+                     f"You can find where each parameter has been saved in the "
+                     f"index located at {save_index_file}.")
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)
 
     def save_unsharded_optimizer(
         self,
@@ -74,7 +145,7 @@ class GeneralCheckpointIO(CheckpointIO):
                            model: nn.Module,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
@@ -89,9 +160,9 @@ class GeneralCheckpointIO(CheckpointIO):
         # shard checkpoint
         state_dict = model.state_dict()
-        state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+        state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
 
-        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
         for idx, shard_pair in enumerate(state_dict_shard):
@@ -128,7 +199,7 @@ class GeneralCheckpointIO(CheckpointIO):
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
-        checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []
 
         for shard_file in checkpoint_files:
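For reference, saving with the default names should leave a directory shaped as the docstring above describes. A hedged sketch of inspecting it, assuming the index follows the HuggingFace-style {"metadata": ..., "weight_map": ...} layout that CheckpointIndexFile mirrors:

    import json

    # Expected layout after saving into "ckpt" with prefix=None:
    #   ckpt/pytorch_optim.bin.index.json   - index file (weight map + metadata)
    #   ckpt/pytorch_optim_group.bin        - torch-saved param_groups list
    #   ckpt/pytorch_optim-000XX.bin        - shards of per-parameter states

    with open("ckpt/pytorch_optim.bin.index.json") as f:
        index = json.load(f)

    print(index["metadata"]["param_groups"])            # e.g. "pytorch_optim_group.bin"
    print(index["metadata"]["total_size"])              # accumulated shard size
    print(sorted(set(index["weight_map"].values())))    # shard files, keyed by param id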
colossalai/checkpoint_io/index_file.py

@@ -111,7 +111,7 @@ class CheckpointIndexFile:
                 return True
         return False
 
-    def get_checkpoint_fileanames(self) -> List[str]:
+    def get_checkpoint_filenames(self) -> List[str]:
         """
         Get the set of checkpoint filenames in the weight map.
@@ -159,6 +159,18 @@ class CheckpointIndexFile:
         """
         return list(self.weight_map.keys())
 
+    def get_param_group_filename(self) -> Union[str, None]:
+        """
+        Get the file name of param_group file if this is a checkpoint for optimizer.
+
+        Returns:
+            str: param_group file name
+        """
+        filename = self.metadata.get("param_groups", None)
+        if filename:
+            return str(self.root_path.joinpath(filename))
+        else:
+            return None
+
     def write_index_file(self, save_index_file):
         """
         Write index file.
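The save path drives this class through three calls: `append_weight_map` once per param id in a shard, `append_meta_data` for bookkeeping, and `write_index_file` at the end. A minimal sketch of that interaction (values illustrative):

    from colossalai.checkpoint_io import CheckpointIndexFile

    index_file = CheckpointIndexFile("ckpt")    # root directory of the checkpoint
    index_file.append_meta_data("param_groups", "pytorch_optim_group.bin")
    index_file.append_weight_map("0", "pytorch_optim-00001.bin")    # param id -> shard file
    index_file.append_weight_map("1", "pytorch_optim-00001.bin")
    index_file.append_meta_data("total_size", 2)
    index_file.write_index_file("pytorch_optim.bin.index.json")

    # On the load side, the new helper resolves the group file against the root:
    print(index_file.get_param_group_filename())    # ckpt/pytorch_optim_group.bin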
colossalai/checkpoint_io/utils.py

 # coding=utf-8
 import re
+from collections import abc as container_abcs
+from collections import defaultdict
+from itertools import chain
 from pathlib import Path
 from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
 
 import torch
 import torch.nn as nn
+from torch.optim import Optimizer
 
 from colossalai.tensor.d_tensor.d_tensor import DTensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
+STATES_NAME = "pytorch_optim.bin"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
+STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+GROUP_FILE_NAME = "pytorch_optim_group.bin"
 
 # ======================================
 # General helper functions
@@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # ======================================
 # Helper functions for saving shard file
 # ======================================
-def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -110,6 +117,50 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
         yield current_block, current_block_size
 
+def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+    """
+    Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not
+    exceed a given size.
+    """
+
+    # Only split state_dict['state']; state_dict['param_groups'] is not considered in this function.
+    states = state_dict['state']
+
+    current_block = {}
+    current_block_size = 0
+
+    for param_id, state in states.items():
+
+        ret_block = None
+        ret_block_size = 0
+
+        # A state might contain more than one tensor.
+        # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+        state_size = 0
+        isDTensor = False
+        for state_tensor in state.values():
+            # If the states are stored as DTensors, mark isDTensor as true.
+            if type(state_tensor) == DTensor:
+                isDTensor = True
+            state_size += calculate_tensor_size(state_tensor)
+
+        if not isDTensor:
+            if current_block_size + state_size > max_shard_size:
+                ret_block = current_block
+                ret_block_size = current_block_size
+                current_block = {}
+                current_block_size = 0
+
+            current_block[param_id] = state
+            current_block_size += state_size
+
+        if ret_block is not None:
+            yield ret_block, ret_block_size
+
+    yield current_block, current_block_size
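Each yielded pair from the new generator is a {param_id: state} block plus its accumulated size. A small hedged demo on a real Adam state dict (max_shard_size in MB, matching size_per_shard elsewhere in this PR):

    import torch
    from torch.optim import Adam

    from colossalai.checkpoint_io.utils import shard_optimizer_checkpoint

    model = torch.nn.Linear(512, 512)
    optimizer = Adam(model.parameters(), lr=1e-3)
    model(torch.randn(4, 512)).sum().backward()
    optimizer.step()    # populate 'step', 'exp_avg', 'exp_avg_sq'

    for shard, size in shard_optimizer_checkpoint(optimizer.state_dict(), max_shard_size=1):
        print(sorted(shard.keys()), size)    # param ids in this block and its size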
 def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
     """
     load shard state dict into model
@@ -179,6 +230,96 @@ def load_state_dict_into_model(model: nn.Module,
         model.__class__.__name__, "\n\t".join(error_msgs)))
 
+def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
+    """
+    Load information of param_groups into an initialized optimizer.
+    """
+
+    # Load the list of param_groups from the given file path.
+    # The params in saved_groups are in the form of integer indices.
+    saved_groups = torch.load(param_group_path)
+    if not isinstance(saved_groups, List):
+        raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
+
+    # The params in param_groups are in the form of pytorch tensors.
+    # For more details, please view the source code of the Optimizer class in pytorch.
+    param_groups = optimizer.param_groups
+
+    # Check the compatibility of saved_groups and param_groups.
+    if len(param_groups) != len(saved_groups):
+        raise ValueError("loaded state dict has a different number of original parameter groups")
+    param_lens = (len(g['params']) for g in param_groups)
+    saved_lens = (len(g['params']) for g in saved_groups)
+    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+        raise ValueError("loaded state dict contains a parameter group "
+                         "that doesn't match the size of optimizer's group")
+
+    # Create a mapping from saved integer id to parameter.
+    id_map = {
+        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)),
+                                       chain.from_iterable((g['params'] for g in param_groups)))
+    }
+
+    # Update parameter groups, setting their 'params' value.
+    def update_group(group, new_group):
+        new_group['params'] = group['params']
+        return new_group
+
+    updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
+
+    optimizer.__dict__.update({'param_groups': updated_groups})
+    return id_map
+
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
+    r"""Copies states from `state_dict` into an Optimizer object.
+
+    Args:
+        optimizer(Optimizer): An initialized Optimizer object to be loaded
+        state_dict(dict): a mapping from tensor index (an integer)
+            to its states to be loaded (a mapping from state name to a tensor).
+        id_map(dict): a mapping from tensor index (an integer)
+            to its corresponding parameter (a tensor) whose states will be updated.
+    """
+
+    def cast(param, value, key=None):
+        r"""Make a deep copy of value, casting all tensors to device of param."""
+        if isinstance(value, torch.Tensor):
+            # Floating-point types are a bit special here. They are the only ones
+            # that are assumed to always match the type of params.
+            # Make sure state['step'] is not casted: https://github.com/pytorch/pytorch/issues/74424
+            if key != "step":
+                if param.is_floating_point():
+                    value = value.to(param.dtype)
+                value = value.to(param.device)
+            return value
+        elif isinstance(value, dict):
+            return {k: cast(param, v, key=k) for k, v in value.items()}
+        elif isinstance(value, container_abcs.Iterable):
+            return type(value)(cast(param, v) for v in value)
+        else:
+            return value
+
+    # Copy state assigned to params (and cast tensors to appropriate types).
+    # State that is not assigned to params is copied as is (needed for
+    # backward compatibility).
+    new_states = defaultdict(dict)
+    for k, v in state_dict.items():
+        if k in id_map:
+            param = id_map[k]
+            new_states[param] = cast(param, v)
+        else:
+            new_states[k] = v
+
+    optimizer.state.update(new_states)
+
+def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    # Do the cleaning up as in the source code of PyTorch.
+    optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
+    optimizer.defaults.setdefault('differentiable', False)
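Taken together, loading works in two stages: `load_param_groups_into_optimizer` rebuilds `param_groups` and returns an `id_map` from saved integer ids to live parameter tensors, then `load_states_into_optimizer` rebinds each shard's states through that map. A hedged end-to-end sketch with hypothetical file paths (normally resolved via the index file):

    import torch
    from torch.optim import Adam

    from colossalai.checkpoint_io.utils import (
        load_param_groups_into_optimizer,
        load_states_into_optimizer,
        sharded_optimizer_loading_epilogue,
    )

    model = torch.nn.Linear(8, 8)
    optimizer = Adam(model.parameters(), lr=1e-3)

    # 'ckpt/pytorch_optim_group.bin' and the shard list are hypothetical paths.
    id_map = load_param_groups_into_optimizer(optimizer, 'ckpt/pytorch_optim_group.bin')
    for shard_file in ['ckpt/pytorch_optim-00001.bin']:
        load_states_into_optimizer(optimizer, torch.load(shard_file), id_map)
    sharded_optimizer_loading_epilogue(optimizer)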
 # ======================================
 # Helper functions for saving state dict
 # ======================================
@@ -203,6 +344,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
     torch.save(state_dict, checkpoint_file_path)
 
+def save_param_groups(state_dict: dict, group_file_path: str) -> None:
+    """
+    Save information of param_groups to the given file path.
+
+    Args:
+        state_dict (dict): state dict.
+        group_file_path (str): path to the group file.
+    """
+    param_groups = state_dict["param_groups"]
+    torch.save(param_groups, group_file_path)
+
 def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
     """
     Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
@@ -392,28 +545,44 @@ def load_state_dict(checkpoint_file_path: Path):
     return torch.load(checkpoint_file_path)
 
-def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
-    if variant is not None and len(variant) > 0:
+def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
+    if prefix is not None and len(prefix) > 0:
         splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
+        splits = splits[:-1] + [prefix] + splits[-1:]
         weights_name = ".".join(splits)
 
     return weights_name
 
-def get_base_filenames(variant: str = None, use_safetensors: bool = False):
+def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
-    generate base weight filenames
+    generate base model weight filenames
     """
     weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-    weights_name = add_variant(weights_name, variant)
+    weights_name = add_prefix(weights_name, prefix)
 
     save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-    save_index_file = add_variant(save_index_file, variant)
+    save_index_file = add_prefix(save_index_file, prefix)
 
     return weights_name, save_index_file
 
+def get_optimizer_base_filenames(prefix: str = None):
+    """
+    generate base optimizer state filenames
+    """
+    states_name = STATES_NAME
+    states_name = add_prefix(states_name, prefix)
+
+    save_index_file = STATES_INDEX_NAME
+    save_index_file = add_prefix(save_index_file, prefix)
+
+    param_group_file = GROUP_FILE_NAME
+    param_group_file = add_prefix(param_group_file, prefix)
+
+    return states_name, save_index_file, param_group_file
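`add_prefix` splices the prefix in before the final extension, so the helpers yield names like the following worked examples (derived from the splitting logic above):

    from colossalai.checkpoint_io.utils import add_prefix, get_optimizer_base_filenames

    assert add_prefix("pytorch_optim.bin", "epoch3") == "pytorch_optim.epoch3.bin"
    assert add_prefix("pytorch_optim.bin.index.json", "epoch3") == "pytorch_optim.bin.index.epoch3.json"

    # With no prefix, the defaults pass through unchanged:
    assert get_optimizer_base_filenames() == (
        "pytorch_optim.bin", "pytorch_optim.bin.index.json", "pytorch_optim_group.bin")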
 def get_shard_filename(weights_name: str, idx: int):
     """
     get shard file name
tests/test_checkpoint_io/test_general_checkpoint_io.py

@@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):
 @pytest.mark.parametrize('use_safetensors', [True, False])
-def test_sharded_checkpoint(use_safetensors: bool):
+def test_sharded_model_checkpoint(use_safetensors: bool):
     # create a model and optimizer
     model = resnet18()
     optimizer = Adam(model.parameters(), lr=0.001)
@@ -100,3 +100,101 @@ def test_sharded_checkpoint(use_safetensors: bool):
     # check for model and optimizer state dict recursively
     check_state_dict_equal(model.state_dict(), new_model.state_dict())
     check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
+
+def test_sharded_optimizer_checkpoint():
+
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam(model.parameters(), lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create temp directories for checkpoint
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_dir = tempfile.TemporaryDirectory()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create new model
+    new_model = resnet18()
+    new_optimizer = Adam(new_model.parameters(), lr=0.001)
+
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(model.state_dict(), new_model.state_dict())
+    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
+
+    # continue running fwd and bwd
+    for _ in range(5):
+        y = new_model(x)
+        loss = y.sum()
+        loss.backward()
+        new_optimizer.step()
+
+    # save the newly-trained optimizer
+    ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create another new model
+    new_new_model = resnet18()
+    new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001)
+
+    ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict())
+    check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())
+
+def test_sharded_optimizer_multiple_param_groups():
+
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam([{'params': model.layer1.parameters()},
+                      {'params': model.layer2.parameters(), 'lr': 0.002}], lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create temp directories for checkpoint
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_dir = tempfile.TemporaryDirectory()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create new model
+    new_model = resnet18()
+    new_optimizer = Adam([{'params': new_model.layer1.parameters()},
+                          {'params': new_model.layer2.parameters(), 'lr': 0.002}], lr=0.001)
+
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(model.state_dict(), new_model.state_dict())
+    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())