Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f5d3a9c2
"vscode:/vscode.git/clone" did not exist on "416a50dbd713edab5ccb39aaf6dd1aecb0520e09"
Unverified
Commit
f5d3a9c2
authored
Apr 02, 2022
by
ver217
Committed by
GitHub
Apr 02, 2022
Browse files
polish checkpoint docstring (#637)
parent
055fbf5b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
37 deletions
+26
-37
colossalai/utils/checkpointing.py
colossalai/utils/checkpointing.py
+26
-37
No files found.
colossalai/utils/checkpointing.py
View file @
f5d3a9c2
...
...
@@ -23,9 +23,10 @@ def broadcast_state_dict(state_dict, parallel_mode):
return
state_dict
[
0
]
def
partition_tensor_parallel_state_dict
(
state_dict
:
OrderedDict
,
parallel_mode
:
ParallelMode
,
dims
:
dict
=
dict
(),
partition_states
:
dict
=
dict
()
):
def
partition_tensor_parallel_state_dict
(
state_dict
:
OrderedDict
,
parallel_mode
:
ParallelMode
,
dims
:
dict
=
dict
(),
partition_states
:
dict
=
dict
()):
src_rank
=
gpc
.
get_ranks_in_group
(
parallel_mode
)[
0
]
depth
=
gpc
.
get_world_size
(
parallel_mode
)
...
...
@@ -51,11 +52,11 @@ def partition_tensor_parallel_state_dict(
def
gather_tensor_parallel_state_dict
(
state_dict
:
OrderedDict
,
parallel_mode
:
ParallelMode
,
dims
:
dict
=
dict
(),
partition_states
:
dict
=
dict
(),
keep_vars
:
bool
=
False
,
state_dict
:
OrderedDict
,
parallel_mode
:
ParallelMode
,
dims
:
dict
=
dict
(),
partition_states
:
dict
=
dict
(),
keep_vars
:
bool
=
False
,
):
dst_rank
=
gpc
.
get_ranks_in_group
(
parallel_mode
)[
0
]
depth
=
gpc
.
get_world_size
(
parallel_mode
)
...
...
@@ -124,11 +125,8 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
def
gather_pipeline_parallel_state_dict
(
state_dict
):
gathered_states
=
(
[
None
for
_
in
range
(
gpc
.
get_world_size
(
ParallelMode
.
PIPELINE
))]
if
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
==
0
else
None
)
gathered_states
=
([
None
for
_
in
range
(
gpc
.
get_world_size
(
ParallelMode
.
PIPELINE
))]
if
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
==
0
else
None
)
dist
.
gather_object
(
state_dict
,
gathered_states
,
...
...
@@ -136,23 +134,18 @@ def gather_pipeline_parallel_state_dict(state_dict):
group
=
gpc
.
get_cpu_group
(
ParallelMode
.
PIPELINE
),
)
state_dict
=
(
OrderedDict
(
chain
.
from_iterable
(
state
.
items
()
for
state
in
gathered_states
))
if
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
==
0
else
OrderedDict
()
)
state_dict
=
(
OrderedDict
(
chain
.
from_iterable
(
state
.
items
()
for
state
in
gathered_states
))
if
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
==
0
else
OrderedDict
())
return
state_dict
def
save_checkpoint
(
file
,
epoch
:
int
,
model
:
torch
.
nn
.
Module
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
**
kwargs
):
def
save_checkpoint
(
file
,
epoch
:
int
,
model
:
torch
.
nn
.
Module
,
optimizer
:
torch
.
optim
.
Optimizer
=
None
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
_LRScheduler
=
None
,
**
kwargs
):
"""Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
lr_scheduler etc. into a checkpoint dictionary.
...
...
@@ -162,8 +155,8 @@ def save_checkpoint(
epoch (int): Epoch number (indicates how many epochs have you trained this model).
model (:class:`torch.nn.Module`): Model to be saved.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
:class:`colossalai.nn.lr_scheduler`], optional):
lr_scheduler to be saved, defaults to None.
lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
:class:`colossalai.nn.lr_scheduler`], optional):
lr_scheduler to be saved, defaults to None.
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
"""
...
...
@@ -195,7 +188,7 @@ def load_checkpoint(
):
"""Loads training states from a checkpoint file.
Args:
Args:
file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
object containing a file name.
model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
...
...
@@ -211,9 +204,8 @@ def load_checkpoint(
Raises:
RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
"""
state_dict
=
(
torch
.
load
(
file
,
map_location
=
torch
.
device
(
"cpu"
))
if
gpc
.
get_local_rank
(
ParallelMode
.
MODEL
)
==
0
else
None
)
state_dict
=
(
torch
.
load
(
file
,
map_location
=
torch
.
device
(
"cpu"
))
if
gpc
.
get_local_rank
(
ParallelMode
.
MODEL
)
==
0
else
None
)
# model states
model_state
=
state_dict
.
pop
(
"model"
)
if
state_dict
is
not
None
else
dict
()
...
...
@@ -231,11 +223,8 @@ def load_checkpoint(
dist
.
gather_object
(
error_msgs
,
all_error_msgs
,
dst
=
dst_rank
,
group
=
gpc
.
get_cpu_group
(
ParallelMode
.
MODEL
))
if
gpc
.
get_global_rank
()
==
0
:
all_error_msgs
=
list
(
chain
.
from_iterable
(
all_error_msgs
))
raise
RuntimeError
(
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
all_error_msgs
)
)
)
raise
RuntimeError
(
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
all_error_msgs
)))
else
:
raise
e
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment