OpenDAS / ColossalAI · Commits

Commit 28b515d6 (unverified)
Authored Apr 01, 2022 by アマデウス; committed by GitHub on Apr 01, 2022
Parent: 77ad24bf

[model checkpoint] updated checkpoint hook (#598)

Showing 2 changed files with 14 additions and 92 deletions:

colossalai/trainer/hooks/__init__.py          +4   -4
colossalai/trainer/hooks/_checkpoint_hook.py  +10  -88
colossalai/trainer/hooks/__init__.py (view file @ 28b515d6)
 from ._base_hook import BaseHook
-from ._checkpoint_hook import LoadCheckpointHook, SaveCheckpointHook
+from ._checkpoint_hook import SaveCheckpointHook
 from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
                         LogTimingByEpochHook, TensorboardHook)
 from ._lr_scheduler_hook import LRSchedulerHook
 from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook

 __all__ = [
-    'BaseHook', 'MetricHook', 'LoadCheckpointHook', 'SaveCheckpointHook', 'LossHook', 'AccuracyHook',
-    'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
-    'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook'
+    'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook',
+    'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook',
+    'LogMetricByStepHook', 'SaveCheckpointHook'
 ]
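
The practical effect of the export change, as a quick sketch (assuming an installation built from this commit):

# SaveCheckpointHook is still part of the public API:
from colossalai.trainer.hooks import SaveCheckpointHook

# LoadCheckpointHook is gone, so pre-commit import sites now fail:
try:
    from colossalai.trainer.hooks import LoadCheckpointHook
except ImportError:
    print('LoadCheckpointHook was removed; load checkpoints before training instead')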
colossalai/trainer/hooks/_checkpoint_hook.py (view file @ 28b515d6)
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-import os.path as osp

 from colossalai.logging import get_dist_logger
 from colossalai.registry import HOOKS
 from colossalai.trainer.hooks import BaseHook
-from colossalai.utils import is_dp_rank_0
-from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
-from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint
+from colossalai.utils.checkpointing import save_checkpoint
 from ._lr_scheduler_hook import LRSchedulerHook
...
@@ -17,9 +14,8 @@ class SaveCheckpointHook(BaseHook):
     """Saves the model by interval in training process.

     Args:
-        interval (int, optional): Saving interval, defaults to 1.
-        checkpoint_dir (str, optional): Directory of saving checkpoint, defaults to None.
-        suffix (str, optional): Saving suffix of the file, defaults to ''.
+        interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1.
+        checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None.
         priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
             defaults to 10. If different hooks share same priority, the order of printing would
             depend on the hooks order in the hook list.
...
@@ -28,19 +24,17 @@ class SaveCheckpointHook(BaseHook):
     def __init__(self,
                  interval: int = 1,
                  checkpoint_dir: str = None,
-                 suffix: str = '',
                  priority: int = 10):
         super().__init__(priority=priority)
         self.interval = interval
         self.checkpoint_dir = checkpoint_dir
-        self.suffix = suffix
         self.logger = get_dist_logger()

         # get lr scheduler from the LRSchedulerHook before train
         self._lr_scheduler = None

     def after_hook_is_attached(self, trainer):
-        # check if lr scheduler is present in LRSchedulerHook
+        # get lr scheduler if exists
         for hook in trainer.hooks:
             if isinstance(hook, LRSchedulerHook):
                 self._lr_scheduler = hook.lr_scheduler
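
A small illustration of the narrowed constructor (hypothetical values; since the suffix parameter was dropped, old call sites that passed it by keyword now raise a TypeError):

from colossalai.trainer.hooks import SaveCheckpointHook

# new signature: only interval, checkpoint_dir and priority remain
hook = SaveCheckpointHook(interval=5, checkpoint_dir='./ckpt.pt')

# pre-commit code like the following no longer works:
#   SaveCheckpointHook(interval=5, checkpoint_dir='./ckpt', suffix='_fp16')
#   -> TypeError: __init__() got an unexpected keyword argument 'suffix'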
...
@@ -51,82 +45,10 @@ class SaveCheckpointHook(BaseHook):
         """
         # save by interval
         if trainer.cur_epoch % self.interval == 0:
-            # only gpus with data parallel rank equals to 0 write to the disk
-            if is_dp_rank_0():
-                save_path = get_checkpoint_path(self.checkpoint_dir,
-                                                trainer.cur_epoch,
-                                                suffix=self.suffix)
-                save_checkpoint(save_path,
-                                trainer.cur_epoch,
-                                trainer.engine.model,
-                                trainer.engine.optimizer,
-                                self._lr_scheduler)
-                self.logger.info(
-                    f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
-
-
-@HOOKS.register_module
-class LoadCheckpointHook(BaseHook):
-    """Loads the model before training process.
-
-    Args:
-        checkpoint_dir (str, optional): Directory of saving checkpoint, defaults to None.
-        epoch (str, optional): Loading checkpoint of setting epoch numbers, defaults to -1.
-            Epoch equals to -1 means choosing the latest checkpoint.
-        finetune (bool, optional): Whether allows to load a part of the model, defaults to False.
-        strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict` of the checkpoint
-            match the names of parameters and buffers in model, defaults to False.
-        suffix (str, optional): Suffix of checkpoint file path, defaults to ''.
-        priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
-            defaults to 0. If different hooks share same priority, the order of printing would
-            depend on the hooks order in the hook list.
-    """
-
-    def __init__(self,
-                 checkpoint_dir: str = None,
-                 epoch: int = -1,
-                 finetune: bool = False,
-                 strict: bool = False,
-                 suffix: str = '',
-                 priority: int = 0) -> None:
-        super().__init__(priority=priority)
-        self.epoch = epoch
-        self.checkpoint_dir = checkpoint_dir
-        self.finetune = finetune
-        self.suffix = suffix
-        self.strict = strict
-        self.logger = get_dist_logger()
-
-    def before_train(self, trainer):
-        """Loads parameters to the model before training.
-        """
-        # check if lr scheduler is present in LRSchedulerHook
-        lr_scheduler = None
-        for hook in trainer.hooks:
-            if isinstance(hook, LRSchedulerHook):
-                lr_scheduler = hook.lr_scheduler
-                break
-
-        # use latest checkpoint if epoch = -1
-        if self.epoch == -1:
-            path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix)
-        else:
-            path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix)
-
-        if osp.exists(path):
-            last_epoch, _ = load_checkpoint(path,
-                                            trainer.engine.model,
-                                            trainer.engine.optimizer,
-                                            lr_scheduler,
-                                            finetune=self.finetune,
-                                            strict=self.strict)
-            if self.finetune:
-                trainer.cur_epoch = 0
-            else:
-                trainer.cur_epoch = last_epoch
-            self.logger.info(f'loaded checkpoint from {path}', ranks=[0])
-        else:
-            raise FileNotFoundError(f'checkpoint is not found at {path}')
+            save_checkpoint(self.checkpoint_dir,
+                            trainer.cur_epoch,
+                            trainer.engine.model,
+                            trainer.engine.optimizer,
+                            self._lr_scheduler)
+            self.logger.info(
+                f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
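
Taken together, the hook now passes checkpoint_dir straight to save_checkpoint on every rank (the is_dp_rank_0 guard is gone, presumably because the new checkpointing utility coordinates which ranks write), and loading has moved out of the hook system entirely. A hedged end-to-end sketch of the post-commit workflow; the model, optimizer, criterion and dataloader objects and the file name are placeholders, and the Trainer calls follow the ColossalAI API of this period:

import colossalai
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import SaveCheckpointHook

# model, optimizer, criterion and train_dataloader are assumed to be defined,
# and the usual colossalai.launch/config setup already done
engine, train_dataloader, *_ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

trainer = Trainer(engine=engine, logger=get_dist_logger())

# checkpoint_dir is now the target file name itself; the old per-epoch
# path helper and the suffix argument are gone
hook_list = [SaveCheckpointHook(interval=1, checkpoint_dir='./model.pt')]

trainer.fit(train_dataloader=train_dataloader,
            epochs=10,
            hooks=hook_list,
            display_progress=True)

# with LoadCheckpointHook removed, resuming is done before trainer.fit, e.g.:
#   from colossalai.utils.checkpointing import load_checkpoint
#   last_epoch = load_checkpoint('./model.pt', model, optimizer, lr_scheduler)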