OpenDAS / ColossalAI · Commits

Commit 80e37eec (unverified), authored Apr 14, 2022 by LuGY, committed by GitHub on Apr 14, 2022
Parent: 1f698f44

fix the ckpt bugs when using DDP (#769)
Showing 1 changed file with 29 additions and 2 deletions:

colossalai/trainer/hooks/_checkpoint_hook.py (+29, -2)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from colossalai.logging import get_dist_logger
from colossalai.registry import HOOKS
@@ -15,7 +15,12 @@ class SaveCheckpointHook(BaseHook):
     Args:
         interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1.
+            If save_by_iter is True, this arg refers to the number of iters between saving.
         checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None.
+        model (torch.nn.Module, optional): The model to save, defaults to None. When not passed,
+            'trainer.engine.model' will be used. We encourage you to pass the model explicitly to
+            avoid unexpected bugs, especially when using **DDP**.
+        save_by_iter (bool, optional): Whether to save the checkpoint by iter, defaults to False.
         priority (int, optional): Priority in the printing; hooks with small priority will be printed
             in front, defaults to 10. If different hooks share the same priority, the order of
             printing depends on the hooks' order in the hook list.
     """
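The docstring's DDP warning is worth unpacking. The diff doesn't spell the bug out, but a plausible reading is the classic state_dict key mismatch: under DDP, `trainer.engine.model` may be the wrapped module, whose parameter keys carry a `module.` prefix, so a checkpoint saved from it won't load back into the bare model. A minimal single-process PyTorch sketch of that mismatch (plain torch, no ColossalAI; the gloo group exists only so DDP can construct):

# Sketch of the state_dict key mismatch that motivates passing the unwrapped
# model to SaveCheckpointHook. Runs as a single plain `python` process.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

model = torch.nn.Linear(4, 2)
ddp_model = DDP(model)

print(list(model.state_dict().keys()))      # ['weight', 'bias']
print(list(ddp_model.state_dict().keys()))  # ['module.weight', 'module.bias']

# A checkpoint taken from the wrapped model fails to load into the bare one:
try:
    model.load_state_dict(ddp_model.state_dict())
except RuntimeError as err:
    print(err)  # missing 'weight'/'bias', unexpected 'module.'-prefixed keys

dist.destroy_process_group()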
@@ -24,10 +29,14 @@ class SaveCheckpointHook(BaseHook):
     def __init__(self,
                  interval: int = 1,
                  checkpoint_dir: str = None,
+                 model: torch.nn.Module = None,
+                 save_by_iter: bool = False,
                  priority: int = 10):
         super().__init__(priority=priority)
         self.interval = interval
         self.checkpoint_dir = checkpoint_dir
+        self.model = model
+        self.save_by_iter = save_by_iter
         self.logger = get_dist_logger()

         # get lr scheduler from the LRSchedulerHook before train
@@ -39,6 +48,24 @@ class SaveCheckpointHook(BaseHook):
             if isinstance(hook, LRSchedulerHook):
                 self._lr_scheduler = hook.lr_scheduler
                 break
+        self.model = self.model if self.model is not None else trainer.engine.model
+
+    def after_train_iter(self, trainer, output, label, loss):
+        """Saves the model after a training iter.
+        """
+        # save by interval
+        if self.save_by_iter and trainer.cur_step % self.interval == 0:
+            save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model,
+                            trainer.engine.optimizer, self._lr_scheduler)
+            self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}',
+                             ranks=[0])
+        else:
+            pass

     def after_train_epoch(self, trainer):
         """Saves the model after a training epoch.
@@ -47,7 +74,7 @@ class SaveCheckpointHook(BaseHook):
         if trainer.cur_epoch % self.interval == 0:
-            save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, trainer.engine.model,
+            save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model,
                             trainer.engine.optimizer, self._lr_scheduler)
             self.logger.info(
...
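For context, a hedged usage sketch of the hook with its new parameters. The hook arguments come from this diff; the surrounding launch/Trainer calls follow the ColossalAI trainer API of this era and are assumed rather than shown in the commit, and `train_dataloader` is a hypothetical placeholder:

# Hedged sketch: run under torchrun / the colossalai launcher; the model, optimizer,
# criterion, and dataloader below are stand-ins, not part of this commit.
import colossalai
import torch
from colossalai.trainer import Trainer, hooks

colossalai.launch_from_torch(config={})                   # empty config: plain DDP setup
model = torch.nn.Linear(4, 2)                             # stand-in for a real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
train_dataloader = ...                                    # hypothetical: your DataLoader here

# engine.model may be a DDP-wrapped copy; keep a handle on the bare model.
engine, train_dataloader, *_ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

trainer = Trainer(engine=engine)
ckpt_hook = hooks.SaveCheckpointHook(
    interval=100,               # with save_by_iter=True this counts iterations, not epochs
    checkpoint_dir='./ckpt.pt',
    model=model,                # pass the unwrapped model, as the new docstring advises
    save_by_iter=True,
)
trainer.fit(train_dataloader=train_dataloader, epochs=10, hooks=[ckpt_hook])

Passing the bare model means `save_checkpoint` sees the same parameter names the model was built with, regardless of what wrapper `engine.model` carries.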