GitLab · OpenDAS / ColossalAI · Commits

Commit c9cff7e7 (unverified)
[checkpointio] General Checkpointing of Sharded Optimizers (#3984)

Authored Jun 15, 2023 by Baizhou Zhang, committed via GitHub on Jun 15, 2023
Parent: 8bcad736

Showing 8 changed files with 399 additions and 38 deletions.
colossalai/booster/plugin/gemini_plugin.py               +3    -3
colossalai/booster/plugin/torch_ddp_plugin.py            +13   -3
colossalai/booster/plugin/torch_fsdp_plugin.py           +5    -4
colossalai/checkpoint_io/checkpoint_io_base.py           +6    -6
colossalai/checkpoint_io/general_checkpoint_io.py        +83   -12
colossalai/checkpoint_io/index_file.py                   +13   -1
colossalai/checkpoint_io/utils.py                        +177  -8
tests/test_checkpoint_io/test_general_checkpoint_io.py   +99   -1
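For orientation before the per-file diffs: the commit threads sharded-optimizer support through the plugin `CheckpointIO` classes, renames the `variant` argument to `prefix`, and implements the actual save/load logic in `GeneralCheckpointIO` and `checkpoint_io/utils.py`. A minimal round-trip of the new functionality, condensed from the test added in this commit (the temporary directory and the torchvision `resnet18` are illustrative):

```python
import tempfile

import torch
from torch.optim import Adam
from torchvision.models import resnet18

from colossalai.checkpoint_io import GeneralCheckpointIO

# Build some optimizer state to checkpoint.
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
model(torch.randn(1, 3, 224, 224)).sum().backward()
optimizer.step()

# Save the optimizer states as shards of at most 10 MB each.
optim_ckpt_dir = tempfile.TemporaryDirectory()
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_optimizer(optimizer, optim_ckpt_dir.name, shard=True, size_per_shard=10)

# Restore into a freshly constructed optimizer; the index file in the
# directory tells load_optimizer that the checkpoint is sharded.
new_optimizer = Adam(resnet18().parameters(), lr=0.001)
ckpt_io.load_optimizer(new_optimizer, optim_ckpt_dir.name)
```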
colossalai/booster/plugin/gemini_plugin.py

```diff
@@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
+from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                            model: GeminiDDP,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
         Save sharded model
         """
         state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
-        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
         for idx, shard_pair in enumerate(state_dict_shard):
```
colossalai/booster/plugin/torch_ddp_plugin.py

```diff
@@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
             super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
@@ -54,11 +53,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
                            model: nn.Module,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
         Save model to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
+            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
 
 
 class TorchDDPModel(ModelWrapper):
```
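The method added above follows the same rank-0 guard as the existing model-saving paths: under DDP every rank holds an identical copy of the optimizer states, so a single writer suffices. A standalone sketch of that delegation pattern (the `coordinator` object here stands in for `colossalai.cluster.DistCoordinator`, which exposes `is_master()`):

```python
from torch.optim import Optimizer

from colossalai.checkpoint_io import GeneralCheckpointIO


class MasterOnlyCheckpointIO(GeneralCheckpointIO):
    """Sketch: delegate to GeneralCheckpointIO, but write from rank 0 only."""

    def __init__(self, coordinator):
        super().__init__()
        self.coordinator = coordinator    # assumed to expose .is_master()

    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
                               size_per_shard: int):
        # Replicated states: one writer is sufficient and avoids file races.
        if self.coordinator.is_master():
            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
```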
colossalai/booster/plugin/torch_fsdp_plugin.py

```diff
+import warnings
 from pathlib import Path
 from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import warnings
 from packaging import version
 from torch.distributed import ProcessGroup
@@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
         utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
 
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
                            size_per_shard: int, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
@@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         raise NotImplementedError("Sharded model checkpoint is not supported yet.")
 
-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
         """
         Save optimizer to checkpoint but only on master process.
         """
         raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
         """
         Load optimizer from checkpoint, but only on master process.
         """
```
colossalai/checkpoint_io/checkpoint_io_base.py

```diff
@@ -103,7 +103,7 @@ class CheckpointIO(ABC):
                    checkpoint: str,
                    shard: bool = False,
                    gather_dtensor: bool = True,
-                   variant: str = None,
+                   prefix: str = None,
                    size_per_shard: int = 1024,
                    use_safetensors: bool = False):
         """
@@ -128,7 +128,7 @@ class CheckpointIO(ABC):
                 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
             gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
-            variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
+            prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in the safetensors format.
         """
@@ -137,11 +137,11 @@ class CheckpointIO(ABC):
             model = model.unwrap()
 
         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
+            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:
             self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """
         Load optimizer from checkpoint.
@@ -157,7 +157,7 @@ class CheckpointIO(ABC):
         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path)
+            self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
         else:
             self.load_unsharded_optimizer(optimizer, checkpoint)
@@ -218,7 +218,7 @@ class CheckpointIO(ABC):
         pass
 
     @abstractmethod
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
                            size_per_shard: int, use_safetensors: bool):
         """
         Save model to sharded checkpoint.
```
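On the model side, the rename is visible only in the keyword of the public entry point. A hypothetical call (directory name and prefix are illustrative); with `shard=True` the model is written as prefixed `pytorch_model` shard files plus an index file, following the `pytorch_model.<prefix>.bin` format documented above:

```python
from torchvision.models import resnet18

from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()
# prefix="epoch3" tags the shard and index filenames; size_per_shard is in MB.
ckpt_io.save_model(resnet18(), "model_ckpt", shard=True, prefix="epoch3", size_per_shard=1024)
```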
colossalai/checkpoint_io/general_checkpoint_io.py

```diff
@@ -11,15 +11,21 @@ from torch.optim import Optimizer
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    get_base_filenames,
+    get_model_base_filenames,
+    get_optimizer_base_filenames,
     get_shard_filename,
     has_index_file,
     is_safetensors_available,
+    load_param_groups_into_optimizer,
     load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_states_into_optimizer,
+    save_param_groups,
     save_state_dict,
-    shard_checkpoint,
+    shard_model_checkpoint,
+    shard_optimizer_checkpoint,
+    sharded_optimizer_loading_epilogue,
 )
 
 __all__ = ['GeneralCheckpointIO']
@@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+        """
+        Load sharded optimizer with the given path to index file.
+        """
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups.
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+            del state_dict
+            gc.collect()
+
+        sharded_optimizer_loading_epilogue(optimizer)
 
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        checkpoint = load_state_dict(checkpoint)
-        optimizer.load_state_dict(checkpoint)
 
     def save_sharded_optimizer(
         self,
@@ -59,7 +83,54 @@ class GeneralCheckpointIO(CheckpointIO):
         prefix: str,
         size_per_shard: int,
     ):
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+        """
+        Save sharded optimizer checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) each storing a shard of optimizer state tensors
+        """
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Offload optimizer states. States are broken into shards within max_shard_size.
+        state_dict = optimizer.state_dict()
+        sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
+
+        # Preparing file paths and index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        save_param_groups(state_dict, group_file_path)
+
+        # Save shards of optimizer states.
+        total_size = 0
+        for idx, shard_pair in enumerate(sharded_state):
+            shard, current_size = shard_pair
+            shard_file = get_shard_filename(states_name, idx)
+            total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+        # Wrap up index file.
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
+        logging.info(f"The optimizer has been split into checkpoint shards. "
+                     f"You can find where each parameter has been saved in the "
+                     f"index located at {save_index_file}.")
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)
 
     def save_unsharded_optimizer(
         self,
@@ -74,7 +145,7 @@ class GeneralCheckpointIO(CheckpointIO):
         model: nn.Module,
         checkpoint_path: str,
         gather_dtensor: bool = False,
-        variant: Optional[str] = None,
+        prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False):
         """
@@ -89,9 +160,9 @@ class GeneralCheckpointIO(CheckpointIO):
         # shard checkpoint
         state_dict = model.state_dict()
-        state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+        state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
 
-        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
         for idx, shard_pair in enumerate(state_dict_shard):
@@ -128,7 +199,7 @@ class GeneralCheckpointIO(CheckpointIO):
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
-        checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []
 
         for shard_file in checkpoint_files:
```
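Based on the docstring of `save_sharded_optimizer` above, a checkpoint directory written with the default (no prefix) filenames would look roughly like this; the shard count is illustrative and depends on `size_per_shard`:

```
optim_ckpt/
├── pytorch_optim.bin.index.json    # maps each param id (as a string) to its shard file
├── pytorch_optim_group.bin         # param_groups saved by save_param_groups()
├── pytorch_optim-00001.bin         # first shard of optimizer state tensors
└── pytorch_optim-00002.bin         # further shards as size_per_shard is exceeded
```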
colossalai/checkpoint_io/index_file.py

```diff
@@ -111,7 +111,7 @@ class CheckpointIndexFile:
                 return True
         return False
 
-    def get_checkpoint_fileanames(self) -> List[str]:
+    def get_checkpoint_filenames(self) -> List[str]:
         """
         Get the set of checkpoint filenames in the weight map.
@@ -159,6 +159,18 @@ class CheckpointIndexFile:
         """
         return list(self.weight_map.keys())
 
+    def get_param_group_filename(self) -> Union[str, None]:
+        """
+        Get the file name of param_group file if this is a checkpoint for optimizer.
+
+        Returns:
+            str: param_group file name
+        """
+        filename = self.metadata.get("param_groups", None)
+        if filename:
+            return str(self.root_path.joinpath(filename))
+        else:
+            return None
+
     def write_index_file(self, save_index_file):
         """
         Write index file.
```
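For reference, an illustrative (assumed, not taken from the commit) index file as the new accessors would consume it: `get_param_group_filename()` reads the `param_groups` entry of the metadata, and `get_checkpoint_filenames()` collects the shard files from the weight map:

```python
# Illustrative index contents; param ids are stored as strings via
# index_file.append_weight_map(str(param_id), shard_file).
index_json = {
    "metadata": {
        "param_groups": "pytorch_optim_group.bin",
        "total_size": 46837,    # illustrative, summed from the shard sizes
    },
    "weight_map": {
        "0": "pytorch_optim-00001.bin",
        "1": "pytorch_optim-00001.bin",
        "2": "pytorch_optim-00002.bin",
    },
}
```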
colossalai/checkpoint_io/utils.py

```diff
 # coding=utf-8
+import re
+from collections import abc as container_abcs
+from collections import defaultdict
+from itertools import chain
 from pathlib import Path
 from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
 
 import torch
 import torch.nn as nn
+from torch.optim import Optimizer
 
 from colossalai.tensor.d_tensor.d_tensor import DTensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
+STATES_NAME = "pytorch_optim.bin"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
+STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+GROUP_FILE_NAME = "pytorch_optim_group.bin"
 
 # ======================================
 # General helper functions
```
```diff
@@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # ======================================
 # Helper functions for saving shard file
 # ======================================
-def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
```
```diff
@@ -110,6 +117,50 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
         yield current_block, current_block_size
 
 
+def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+    """
+    Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not
+    exceed a given size.
+    """
+    # Only split state_dict['state']; state_dict['param_groups'] is not considered in this function.
+    states = state_dict['state']
+
+    current_block = {}
+    current_block_size = 0
+
+    for param_id, state in states.items():
+
+        ret_block = None
+        ret_block_size = 0
+
+        # A state might contain more than one tensor.
+        # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+        state_size = 0
+        isDTensor = False
+        for state_tensor in state.values():
+            # If the states are stored as DTensors, mark isDTensor as true.
+            if type(state_tensor) == DTensor:
+                isDTensor = True
+            state_size += calculate_tensor_size(state_tensor)
+
+        if not isDTensor:
+            if current_block_size + state_size > max_shard_size:
+                ret_block = current_block
+                ret_block_size = current_block_size
+                current_block = {}
+                current_block_size = 0
+
+            current_block[param_id] = state
+            current_block_size += state_size
+
+        if ret_block is not None:
+            yield ret_block, ret_block_size
+
+    yield current_block, current_block_size
+
+
 def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
     """
     load shard state dict into model
```
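A small sketch of how this generator behaves, assuming a torch version where Adam stores its `'step'` counter as a tensor (the `state['step']` guard further down makes the same assumption); the model size and `max_shard_size` are illustrative:

```python
import torch
from torch.optim import Adam

from colossalai.checkpoint_io.utils import shard_optimizer_checkpoint

model = torch.nn.Linear(256, 256)
optimizer = Adam(model.parameters(), lr=1e-3)
model(torch.randn(8, 256)).sum().backward()
optimizer.step()

# Walks state_dict['state'] param by param; a new (shard, size) pair is yielded
# whenever adding the next param's states would exceed max_shard_size (in MB).
for shard, size_mb in shard_optimizer_checkpoint(optimizer.state_dict(), max_shard_size=1):
    print(sorted(shard.keys()), size_mb)    # param ids grouped into blocks of at most ~1 MB
```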
```diff
@@ -179,6 +230,96 @@ def load_state_dict_into_model(model: nn.Module,
                            model.__class__.__name__, "\n\t".join(error_msgs)))
 
 
+def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
+    """
+    Load information of param_groups into an initialized optimizer.
+    """
+    # Load list of param_groups from given file path.
+    # The params in saved_groups are in the form of integer indices.
+    saved_groups = torch.load(param_group_path)
+    if not isinstance(saved_groups, List):
+        raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
+
+    # The params in param_groups are in the form of pytorch tensors.
+    # For more details, please view source code of Optimizer class in pytorch.
+    param_groups = optimizer.param_groups
+
+    # Check the compatibility of saved_groups and param_groups.
+    if len(param_groups) != len(saved_groups):
+        raise ValueError("loaded state dict has a different number of original parameter groups")
+    param_lens = (len(g['params']) for g in param_groups)
+    saved_lens = (len(g['params']) for g in saved_groups)
+    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+        raise ValueError("loaded state dict contains a parameter group "
+                         "that doesn't match the size of optimizer's group")
+
+    # Creating mapping from id to parameters.
+    id_map = {
+        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)),
+                                       chain.from_iterable((g['params'] for g in param_groups)))
+    }
+
+    # Update parameter groups, setting their 'params' value.
+    def update_group(group, new_group):
+        new_group['params'] = group['params']
+        return new_group
+
+    updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
+    optimizer.__dict__.update({'param_groups': updated_groups})
+
+    return id_map
+
+
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
+    r"""Copies states from `state_dict` into an Optimizer object.
+
+    Args:
+        optimizer (Optimizer): An initialized Optimizer object to be loaded
+        state_dict (dict): A mapping from tensor index (an integer)
+            to its states to be loaded (a mapping from state name to a tensor).
+        id_map (dict): A mapping from tensor index (an integer)
+            to its corresponding parameter (a tensor) whose states will be updated.
+    """
+
+    def cast(param, value, key=None):
+        r"""Make a deep copy of value, casting all tensors to device of param."""
+        if isinstance(value, torch.Tensor):
+            # Floating-point types are a bit special here. They are the only ones
+            # that are assumed to always match the type of params.
+            # Make sure state['step'] is not casted: https://github.com/pytorch/pytorch/issues/74424
+            if key != "step":
+                if param.is_floating_point():
+                    value = value.to(param.dtype)
+                value = value.to(param.device)
+            return value
+        elif isinstance(value, dict):
+            return {k: cast(param, v, key=k) for k, v in value.items()}
+        elif isinstance(value, container_abcs.Iterable):
+            return type(value)(cast(param, v) for v in value)
+        else:
+            return value
+
+    # Copy state assigned to params (and cast tensors to appropriate types).
+    # State that is not assigned to params is copied as is (needed for
+    # backward compatibility).
+    new_states = defaultdict(dict)
+    for k, v in state_dict.items():
+        if k in id_map:
+            param = id_map[k]
+            new_states[param] = cast(param, v)
+        else:
+            new_states[k] = v
+
+    optimizer.state.update(new_states)
+
+
+def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    # Do the cleaning up as in the source code of PyTorch.
+    optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
+    optimizer.defaults.setdefault('differentiable', False)
+
+
 # ======================================
 # Helper functions for saving state dict
 # ======================================
```
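To make the index translation concrete, a minimal sketch (assumed shapes, not from the commit) of the structures these helpers exchange: saved shards key states by integer param index, and the `id_map` returned by `load_param_groups_into_optimizer` maps each index back to the live parameter tensor, which `load_states_into_optimizer` then uses as the key in `optimizer.state`:

```python
import torch
from torch.optim import Adam

model = torch.nn.Linear(4, 2)
optimizer = Adam(model.parameters(), lr=1e-3)

# Shape of one loaded shard: {param_index: {state_name: tensor}}.
shard_state_dict = {
    0: {'step': torch.tensor(1.), 'exp_avg': torch.zeros(2, 4), 'exp_avg_sq': torch.zeros(2, 4)},
}

# Shape of id_map, as built from the zipped saved/current param_groups:
# {param_index: parameter tensor}.
id_map = {i: p for i, p in enumerate(model.parameters())}
```

After `load_states_into_optimizer(optimizer, shard_state_dict, id_map)`, the state for index 0 is re-keyed by the weight tensor itself, matching how `torch.optim.Optimizer.load_state_dict` rebuilds its state.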
```diff
@@ -203,6 +344,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
     torch.save(state_dict, checkpoint_file_path)
 
 
+def save_param_groups(state_dict: dict, group_file_path: str) -> None:
+    """
+    Save information of param_groups to given file path.
+
+    Args:
+        state_dict (dict): state dict.
+        group_file_path (str): path to the group file.
+    """
+    param_groups = state_dict["param_groups"]
+    torch.save(param_groups, group_file_path)
+
+
 def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
     """
     Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
```
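Worth noting for the round-trip: in `Optimizer.state_dict()` the `'params'` entries are already integer indices rather than tensors, so the group file stores exactly the index-based groups that `load_param_groups_into_optimizer` later validates. A tiny check (the path is illustrative):

```python
import torch
from torch.optim import Adam

from colossalai.checkpoint_io.utils import save_param_groups

optimizer = Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-3)
print(optimizer.state_dict()["param_groups"][0]["params"])    # [0, 1]: indices, not tensors

save_param_groups(optimizer.state_dict(), "/tmp/pytorch_optim_group.bin")
```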
```diff
@@ -392,28 +545,44 @@ def load_state_dict(checkpoint_file_path: Path):
     return torch.load(checkpoint_file_path)
 
 
-def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
-    if variant is not None and len(variant) > 0:
+def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
+    if prefix is not None and len(prefix) > 0:
         splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
+        splits = splits[:-1] + [prefix] + splits[-1:]
         weights_name = ".".join(splits)
 
     return weights_name
 
 
-def get_base_filenames(variant: str = None, use_safetensors: bool = False):
+def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
-    generate base weight filenames
+    generate base model weight filenames
     """
     weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-    weights_name = add_variant(weights_name, variant)
+    weights_name = add_prefix(weights_name, prefix)
 
     save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-    save_index_file = add_variant(save_index_file, variant)
+    save_index_file = add_prefix(save_index_file, prefix)
 
     return weights_name, save_index_file
 
 
+def get_optimizer_base_filenames(prefix: str = None):
+    """
+    generate base optimizer state filenames
+    """
+    states_name = STATES_NAME
+    states_name = add_prefix(states_name, prefix)
+
+    save_index_file = STATES_INDEX_NAME
+    save_index_file = add_prefix(save_index_file, prefix)
+
+    param_group_file = GROUP_FILE_NAME
+    param_group_file = add_prefix(param_group_file, prefix)
+
+    return states_name, save_index_file, param_group_file
+
+
 def get_shard_filename(weights_name: str, idx: int):
     """
     get shard file name
```
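A quick check of the renamed helper (the `epoch3` prefix is illustrative): `add_prefix` splices the prefix in before the final extension, yielding the `pytorch_model.<prefix>.bin` format documented in `CheckpointIO.save_model`:

```python
from colossalai.checkpoint_io.utils import add_prefix, get_optimizer_base_filenames

assert add_prefix("pytorch_optim.bin", "epoch3") == "pytorch_optim.epoch3.bin"
assert add_prefix("pytorch_optim.bin", None) == "pytorch_optim.bin"

# For index files, the prefix lands before the last suffix only:
assert add_prefix("pytorch_optim.bin.index.json", "epoch3") == "pytorch_optim.bin.index.epoch3.json"

states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix="epoch3")
# ('pytorch_optim.epoch3.bin', 'pytorch_optim.bin.index.epoch3.json', 'pytorch_optim_group.epoch3.bin')
```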
tests/test_checkpoint_io/test_general_checkpoint_io.py

```diff
@@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):
 
 @pytest.mark.parametrize('use_safetensors', [True, False])
-def test_sharded_checkpoint(use_safetensors: bool):
+def test_sharded_model_checkpoint(use_safetensors: bool):
     # create a model and optimizer
     model = resnet18()
     optimizer = Adam(model.parameters(), lr=0.001)
@@ -100,3 +100,101 @@ def test_sharded_checkpoint(use_safetensors: bool):
     # check for model and optimizer state dict recursively
     check_state_dict_equal(model.state_dict(), new_model.state_dict())
     check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
+
+
+def test_sharded_optimizer_checkpoint():
+
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam(model.parameters(), lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create temp directories for checkpoint
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_dir = tempfile.TemporaryDirectory()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create new model
+    new_model = resnet18()
+    new_optimizer = Adam(new_model.parameters(), lr=0.001)
+
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(model.state_dict(), new_model.state_dict())
+    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
+
+    # continue running fwd and bwd
+    for _ in range(5):
+        y = new_model(x)
+        loss = y.sum()
+        loss.backward()
+        new_optimizer.step()
+
+    # save the newly-updated model and optimizer
+    ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create another new model
+    new_new_model = resnet18()
+    new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001)
+
+    ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict())
+    check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())
+
+
+def test_sharded_optimizer_multiple_param_groups():
+
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam([{'params': model.layer1.parameters()},
+                      {'params': model.layer2.parameters(), 'lr': 0.002}], lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create temp directories for checkpoint
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_dir = tempfile.TemporaryDirectory()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create new model
+    new_model = resnet18()
+    new_optimizer = Adam([{'params': new_model.layer1.parameters()},
+                          {'params': new_model.layer2.parameters(), 'lr': 0.002}], lr=0.001)
+
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(model.state_dict(), new_model.state_dict())
+    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
```
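Both new tests run on CPU without any distributed setup, so they should be runnable directly; a minimal invocation from the repository root (sketch, assuming pytest is installed):

```python
import pytest

# Run only the checkpoint IO tests touched by this commit.
pytest.main(["-q", "tests/test_checkpoint_io/test_general_checkpoint_io.py"])
```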