OpenDAS / ColossalAI · Commits

Commit acae68eb (Unverified), authored Apr 01, 2022 by アマデウス, committed by GitHub on Apr 01, 2022

[model checkpoint] updated checkpoint save/load utils (#592)

Parent: 1c40ee87

Showing 2 changed files with 218 additions and 176 deletions:

    colossalai/utils/__init__.py        +5    -4
    colossalai/utils/checkpointing.py   +213  -172
colossalai/utils/__init__.py (view file @ acae68eb)
The import block at this revision re-exports the checkpoint helpers and adds `ensure_path_exists` to the `.common` import:

```python
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .activation_checkpoint import checkpoint
from .checkpointing import load_checkpoint, save_checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
                     ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
                     is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
                     param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
                     sync_model_param)
from .data_sampler import DataParallelSampler, get_dataloader
```

The previous revision of the `.common` import listed the same names without `ensure_path_exists`:

```python
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
                     free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
                     is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
                     param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
                     sync_model_param)
```

Export list (hunk @@ -18,5 +18,6 @@ __all__ = [) at this revision:

```python
    'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
    'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
    'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
    'ensure_path_exists'
]
```

Previously the list ended with `'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector'`.
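With the re-export in place, the checkpoint helpers can be imported from the package root as well as from the submodule. A quick sanity-check sketch (assuming colossalai at this revision is installed):

```python
# Both spellings resolve to the same functions thanks to the re-export above.
from colossalai.utils import load_checkpoint, save_checkpoint
from colossalai.utils.checkpointing import load_checkpoint as lc, save_checkpoint as sc

assert load_checkpoint is lc
assert save_checkpoint is sc
```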
colossalai/utils/checkpointing.py (view file @ acae68eb)
checkpointing.py is essentially rewritten (+213 -172): the epoch/path-based helpers are removed, and distributed scatter/gather/broadcast utilities plus file-object-based save_checkpoint/load_checkpoint take their place.

Added imports:

```python
from collections import OrderedDict
from itertools import chain

from colossalai.communication.collective import scatter_object_list
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
```

Removed imports and the old export list:

```python
import os
import os.path as osp
import re
from typing import Tuple
from pathlib import Path

from colossalai.context import Config

__all__ = [
    'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
    'load_checkpoint'
]
```

Retained imports:

```python
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
```

Removed helper (the old Config-flattening utility):

```python
def unwrap_config(config: Config):
    """Unwrap Config objects to normal dicts
    """
    config_dict = dict()
    for k, v in config.items():
        if isinstance(v, dict):
            config_dict[k] = unwrap_config(v)
        else:
            config_dict[k] = v

    return config_dict
```
Removed filename and directory helpers:

```python
def _get_ranks_name():
    # tensor parallel
    tp_local_rank = 0
    if gpc.is_initialized(ParallelMode.TENSOR):
        tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)

    # pipeline parallel
    pp_local_rank = 0
    if gpc.is_initialized(ParallelMode.PIPELINE):
        pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
    return ranks_name


def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
    ranks_name = _get_ranks_name()
    return f'epoch{epoch}-{ranks_name}{suffix}.pt'


def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
    """This is a function to generate the checkpoint path from the tuple
    (checkpoint_dir, epoch, suffix, gpu_parallel_rank).
    This is useful during generation and recuperation of the checkpoint.

    Args:
        checkpoint_dir (str): Set up a directory for saving checkpoints.
        epoch (int): Epoch number (indicate how many epochs have you trained this model).
        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''

    Returns:
        str: The checkpoint path to be generated.
    """
    ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
    return os.path.join(checkpoint_dir, ckpt_filename)


def _ensure_directory_exists(filename: str):
    # ensure the directory exists
    dirpath = os.path.dirname(filename)
    if not os.path.exists(dirpath):
        Path(dirpath).mkdir(parents=True, exist_ok=True)
```
Added at module level:

```python
from .common import is_using_pp

__all__ = ["save_checkpoint", "load_checkpoint"]
```
Removed:

```python
def get_latest_checkpoint_pattern(suffix: str = ''):
    """Generate Regular expression of the latest checkpoint's pattern.

    Args:
        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''.

    Returns:
        str: The regular expression of checkpoint pattern.
    """
    ranks_name = _get_ranks_name()
    pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
    ckpt_pattern = re.compile(pattern)
    return ckpt_pattern
```

Added:

```python
def broadcast_state_dict(state_dict, parallel_mode):
    state_dict = [state_dict.copy() if isinstance(state_dict, dict) else state_dict]
    src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    dist.broadcast_object_list(state_dict, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
    return state_dict[0]
```
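`broadcast_state_dict` wraps the object in a one-element list because `dist.broadcast_object_list` fills its argument in place. A minimal single-process sketch of that underlying primitive (plain `torch.distributed` with the gloo backend, not ColossalAI's `gpc` groups):

```python
import os

import torch.distributed as dist

# Hypothetical single-process setup so broadcast_object_list can run locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", world_size=1, rank=0)

state = [{"epoch": 3, "lr": 0.01}]       # wrapped in a list, as broadcast_state_dict does
dist.broadcast_object_list(state, src=0)
print(state[0])                          # every rank now holds the rank-0 copy

dist.destroy_process_group()
```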
Removed:

```python
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
    """This is a function to retrieve the latest checkpoint path from the tuple
    (checkpoint_dir, suffix, gpu_parallel_rank).
    This is useful during recuperation of the checkpoint, especially when you do not know the epoch number.

    Args:
        checkpoint_dir (str): Directory for saving checkpoints
        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''

    Returns:
        str: The latest retrieved checkpoint path.

    Raises:
        FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given.
    """
    CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
    last_epoch = -1
    assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'
    for filename in os.listdir(checkpoint_dir):
        ret = CKPT_NAME_PAT.match(filename)
        if ret:
            epoch = int(ret[0].split('-')[0].lstrip('epoch'))
            if epoch > last_epoch:
                last_epoch = epoch

    if last_epoch == -1:
        ranks_name = _get_ranks_name()
        raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")
    else:
        target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
        path = osp.join(checkpoint_dir, target_file)
        return path
```

Added:

```python
def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
                                         parallel_mode: ParallelMode,
                                         dims: dict = dict(),
                                         partition_states: dict = dict()):
    src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    depth = gpc.get_world_size(parallel_mode)

    if gpc.get_local_rank(parallel_mode) == 0:
        partitioned_state_list = [dict() for _ in range(depth)]

        for key in list(state_dict.keys()):
            param = state_dict.pop(key)
            dim = dims.get(key, 0)
            do_partition = partition_states.get(key, True)
            if do_partition:
                param = torch.chunk(param, depth, dim=dim)
            for i, p in enumerate(partitioned_state_list):
                p[key] = param[i] if do_partition else param
    else:
        partitioned_state_list = [None for _ in range(depth)]

    partitioned_state = [None]
    scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank,
                        group=gpc.get_cpu_group(parallel_mode))
    return partitioned_state[0]
```
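For reference, the removed naming scheme encoded the epoch and the tensor-/pipeline-parallel ranks directly in the filename (`epoch{n}-tp{t}-pp{p}{suffix}.pt`). A standalone sketch of how that pattern picks the newest checkpoint, using plain `re` and made-up filenames:

```python
import re

ranks_name = "tp0-pp1"     # hypothetical: tensor-parallel rank 0, pipeline-parallel rank 1
suffix = ""
ckpt_pattern = re.compile(r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix))

filenames = ["epoch3-tp0-pp1.pt", "epoch12-tp0-pp1.pt", "epoch12-tp1-pp0.pt", "notes.txt"]
last_epoch = -1
for filename in filenames:
    ret = ckpt_pattern.match(filename)
    if ret:
        # same parsing as the removed get_latest_checkpoint_path
        epoch = int(ret[0].split('-')[0].lstrip('epoch'))
        last_epoch = max(last_epoch, epoch)

print(last_epoch)   # 12 -> this rank would reload epoch12-tp0-pp1.pt
```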
Added:

```python
def gather_tensor_parallel_state_dict(
    state_dict: OrderedDict,
    parallel_mode: ParallelMode,
    dims: dict = dict(),
    partition_states: dict = dict(),
    keep_vars: bool = False,
):
    dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    depth = gpc.get_world_size(parallel_mode)

    for key in list(state_dict.keys()):
        param = state_dict.pop(key)
        param = param if keep_vars else param.detach()
        dim = dims.get(key, 0)
        do_partition = partition_states.get(key, True)
        if do_partition:
            temp = param.transpose(0, dim).contiguous()
            gather_list = None
            if gpc.get_local_rank(parallel_mode) == 0:
                shape = list(param.shape)
                shape[0], shape[dim] = shape[dim], shape[0]
                shape[0] *= depth
                param = torch.empty(shape, dtype=param.dtype, device=param.device)
                gather_list = list(torch.chunk(param, depth, dim=0))
            dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode))
            param = torch.transpose(param, 0, dim)
        # update params in state_dict only on local rank 0
        if gpc.get_local_rank(parallel_mode) == 0:
            state_dict[key] = param

    return state_dict
```
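Both tensor-parallel helpers hinge on `torch.chunk` splitting a parameter into `depth` equal slices along its partitioned dimension, with concatenation as the exact inverse. A tiny standalone sketch of that round trip (no distributed setup, hypothetical 2-way tensor parallelism):

```python
import torch

depth = 2                                    # hypothetical tensor-parallel world size
weight = torch.arange(12.).reshape(4, 3)     # toy parameter, partitioned along dim 0

# what rank 0 does in partition_tensor_parallel_state_dict
shards = torch.chunk(weight, depth, dim=0)   # one shard per tensor-parallel rank

# what gather_tensor_parallel_state_dict reassembles via dist.gather
restored = torch.cat(shards, dim=0)

assert torch.equal(weight, restored)
print([tuple(s.shape) for s in shards])      # [(2, 3), (2, 3)]
```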
Added:

```python
def _send_state_dict(state_dict, dst, parallel_mode):
    state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict)
    dist.send(state_size, dst, group=gpc.get_cpu_group(parallel_mode))
    dist.send(state_tensor, dst, group=gpc.get_cpu_group(parallel_mode))


def _recv_state_dict(src, parallel_mode):
    state_size = torch.tensor([0], dtype=torch.long)
    dist.recv(state_size, src, group=gpc.get_cpu_group(parallel_mode))
    state_tensor = torch.empty(state_size.item(), dtype=torch.uint8)
    dist.recv(state_tensor, src, group=gpc.get_cpu_group(parallel_mode))
    state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size)
    return state_dict
```
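The send/recv pair relies on torch's private `_object_to_tensor` / `_tensor_to_object` helpers to ship an arbitrary state dict as a uint8 tensor plus its length. A rough standalone sketch of the same serialize-then-rebuild idea using `pickle` (an illustration only, not the private API above):

```python
import pickle

import torch

state = {"layer.weight": torch.ones(2, 2), "epoch": 3}

# object -> byte tensor + size (the role _object_to_tensor plays in _send_state_dict)
payload = pickle.dumps(state)
state_tensor = torch.tensor(list(payload), dtype=torch.uint8)
state_size = torch.tensor([state_tensor.numel()], dtype=torch.long)

# ...state_size and state_tensor are what travel through dist.send / dist.recv...

# byte tensor -> object (the role _tensor_to_object plays in _recv_state_dict)
restored = pickle.loads(bytes(state_tensor.tolist()))
print(restored["epoch"], restored["layer.weight"].shape)
```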
Added:

```python
def partition_pipeline_parallel_state_dict(model, state_dict):
    pipeline_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # receive all states from prev stage
        if not gpc.is_first_rank(ParallelMode.PIPELINE):
            state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
        # move states to output
        for name, _ in model.named_parameters(recurse=True):
            if name in state_dict:
                pipeline_state[name] = state_dict.pop(name)
        for name, _ in model.named_buffers(recurse=True):
            if name in state_dict:
                pipeline_state[name] = state_dict.pop(name)
        for name, _ in model.named_modules():
            extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX
            if extra_state_key in state_dict:
                pipeline_state[extra_state_key] = state_dict.pop(extra_state_key)
        # send rest states to next stage
        if not gpc.is_last_rank(ParallelMode.PIPELINE):
            _send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)

    return pipeline_state
```
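The filtering above simply pops the entries the local pipeline stage actually owns (parameters, buffers, and extra-state keys) and forwards the rest. A minimal single-process sketch of that selection step (hypothetical two-stage split of a toy model, no communication):

```python
from collections import OrderedDict

import torch.nn as nn

full_model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
full_state = full_model.state_dict()       # what rank 0 loads from disk

# hypothetical pipeline stage 1 holds only the second Linear layer
stage_module = full_model[1]
stage_prefix = "1."                        # names in the full dict carry this prefix

stage_state = OrderedDict()
for name, _ in stage_module.named_parameters(recurse=True):
    key = stage_prefix + name
    if key in full_state:
        stage_state[key] = full_state.pop(key)

print(list(stage_state.keys()))   # ['1.weight', '1.bias'] stay on this stage
print(list(full_state.keys()))    # ['0.weight', '0.bias'] would travel on to another stage
```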
Added:

```python
def gather_pipeline_parallel_state_dict(state_dict):
    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
    dist.gather_object(
        state_dict,
        gathered_states,
        dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0],
        group=gpc.get_cpu_group(ParallelMode.PIPELINE),
    )

    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())

    return state_dict
```
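On the destination rank, the per-stage dictionaries are then concatenated in stage order. A tiny sketch of that `chain.from_iterable` merge with made-up stage dicts:

```python
from collections import OrderedDict
from itertools import chain

# hypothetical per-stage state dicts, as dist.gather_object would collect them
gathered_states = [
    OrderedDict([("0.weight", "stage0 tensor"), ("0.bias", "stage0 tensor")]),
    OrderedDict([("1.weight", "stage1 tensor"), ("1.bias", "stage1 tensor")]),
]

merged = OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
print(list(merged.keys()))   # ['0.weight', '0.bias', '1.weight', '1.bias']
```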
Removed (the old directory/epoch-based save_checkpoint):

```python
def save_checkpoint(checkpoint_path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    **kwargs):
    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as
    model, optimizer, lr_scheduler etc. into a checkpoint dictionary.
    This method can be used for both :class:`colossalai.nn.BaseModel` and normal :class:`torch.nn.Module`.

    Args:
        checkpoint_path (str): Set up a directory for saving checkpoints.
        epoch (int): Epoch number (indicate how many epochs have you trained this model).
        model (:class:`torch.nn.Module`): Model to be registered.
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be registered.
        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be registered, defaults to None.
        kwargs (dict): additional parameters to be saved.
    """
    # for compatibility with normal pytorch nn.Module
    if hasattr(model, 'state_dict_for_save_checkpoint'):
        model_sd = model.state_dict_for_save_checkpoint()
    else:
        model_sd = model.state_dict()

    # ckpt container
    checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
    if lr_scheduler is not None:
        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

    _ensure_directory_exists(checkpoint_path)
    torch.save(checkpoint, checkpoint_path)
```

Added (the new file-object-based save_checkpoint):

```python
def save_checkpoint(file,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    **kwargs):
    """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
    lr_scheduler etc. into a checkpoint dictionary.

    Args:
        file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
            file name.
        epoch (int): Epoch number (indicates how many epochs have you trained this model).
        model (:class:`torch.nn.Module`): Model to be saved.
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    checkpoint = {"epoch": epoch}

    model_state = model.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        model_state = gather_pipeline_parallel_state_dict(model_state)

    if gpc.get_global_rank() == 0:
        checkpoint["model"] = model_state

        # if optimizer is not None:
        #     checkpoint['optimizer'] = optimizer.state_dict()

        # if lr_scheduler is not None:
        #     checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

        torch.save(checkpoint, file, **kwargs)
```
Removed (the old load_checkpoint, which also restored optimizer and lr_scheduler state):

```python
def load_checkpoint(checkpoint_path: str,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    finetune: bool = False,
                    strict: bool = True) -> Tuple:
    """Loads the checkpoint file.
    If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
    and its descendants.
    If finetune is True, then only the weights and buffers of model should be reloaded.
    If strict is True, then the keys of state_dict must exactly match the keys returned
    by this module’s state_dict() function.

    Args:
        checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict.
        model (:class:`torch.nn.Module`): Model to reload parameters and buffers.
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
            lr_scheduler to recuperate, defaults to None.
        finetune (bool, optional): Whether to finetune the model with new dataset or
            continue the pre-training, defaults to False.
        strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
            of the checkpoint match the names of parameters and buffers in model, defaults to True.

    Returns:
        Tuple(int, ``checkpoint``): The tuple (the epoch number of the checkpoint retrieved, the checkpoint retrieved).

    Raises:
        ValueError: Raise error if the model/optimizer cannot successfully be recuperated
    """
    # Load the checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    try:
        last_epoch = checkpoint.pop('epoch') if not finetune else 0
        model.load_state_dict(checkpoint.pop('model'), strict=strict)
    except KeyError:
        raise ValueError('Checkpoint is corrupted')

    if not finetune:
        try:
            optimizer.load_state_dict(checkpoint.pop('optimizer'))
        except KeyError:
            raise ValueError('Checkpoint is corrupted')

    if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
        lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler'))

    return last_epoch, checkpoint
```

Added (the new load_checkpoint):

```python
def load_checkpoint(
    file,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer = None,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
    strict: bool = True,
):
    """Loads training states from a checkpoint file.

    Args:
        file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
            object containing a file name.
        model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
            lr_scheduler to recuperate, defaults to None.
        strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
            of the checkpoint match the names of parameters and buffers in model, defaults to True.

    Returns:
        int: The saved epoch number.

    Raises:
        RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
    """
    state_dict = (torch.load(file, map_location=torch.device("cpu"))
                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)

    # model states
    model_state = state_dict.pop("model") if state_dict is not None else dict()
    # pipeline
    if is_using_pp():
        model_state = partition_pipeline_parallel_state_dict(model, model_state)
    try:
        model.load_state_dict(model_state, strict=strict)
    except RuntimeError as e:
        error_msgs = str(e)
        if error_msgs.startswith("Error(s) in loading state_dict for "):
            error_msgs = error_msgs.split("\n\t")[1:]
            dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0]
            all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
            dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
            if gpc.get_global_rank() == 0:
                all_error_msgs = list(chain.from_iterable(all_error_msgs))
                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
        else:
            raise e

    # broadcast the rest states
    state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)

    # # optimizer states
    # if optimizer is not None and 'optimizer' in state_dict:
    #     optimizer.load_state_dict(state_dict['optimizer'])

    # # lr scheduler states
    # if lr_scheduler is not None and 'lr_scheduler' in state_dict:
    #     lr_scheduler.load_state_dict(state_dict['lr_scheduler'])

    # last epoch
    last_epoch = state_dict.pop("epoch", -1)

    return last_epoch
```
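Taken together, the new pair writes a small `{'epoch': ..., 'model': ...}` dictionary on global rank 0 and restores it on every rank (optimizer and lr_scheduler restoration is still commented out at this revision). A single-process stand-in for that round trip using plain torch, since the real calls need an initialized ColossalAI parallel context:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# what save_checkpoint(file, epoch, model, ...) writes on global rank 0
checkpoint = {"epoch": 3, "model": model.state_dict()}
torch.save(checkpoint, "demo_ckpt.pt")

# what load_checkpoint(file, model, ...) does when resuming
state_dict = torch.load("demo_ckpt.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict.pop("model"), strict=True)
last_epoch = state_dict.pop("epoch", -1)
print("resume from epoch", last_epoch + 1)
```

In a real multi-rank run the same round trip goes through `save_checkpoint(file, epoch, model, optimizer, lr_scheduler)` and `last_epoch = load_checkpoint(file, model, optimizer, lr_scheduler)` once `colossalai.launch`/`colossalai.initialize` has set up the parallel context.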