OpenDAS / MMCV · Commit 27c81690 (unverified)

Authored Sep 21, 2018 by Kai Chen; committed by GitHub on Sep 21, 2018

Merge pull request #7 from OceanPang/env

fix checkpoint & runner bugs

Parents: 923091b5 818c40c3
Showing 4 changed files with 22 additions and 13 deletions.
mmcv/torchpack/hooks/__init__.py           +2 −2
mmcv/torchpack/hooks/checkpoint_saver.py   +1 −1
mmcv/torchpack/hooks/optimizer_stepper.py  +1 −1
mmcv/torchpack/runner/runner.py            +18 −9
mmcv/torchpack/hooks/__init__.py

 from .hook import Hook
-from .checkpoint_saver import CheckpointSaverHook
+from .checkpoint_saver import CheckpointHook
 from .closure import ClosureHook
 from .lr_updater import LrUpdaterHook
-from .optimizer_stepper import OptimizerStepperHook
+from .optimizer_stepper import OptimizerHook
 from .iter_timer import IterTimerHook
 from .logger import *
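Both hooks are renamed here, so downstream imports change accordingly. A minimal sketch (the classes themselves are unchanged apart from their names):

# Before this commit:
#   from mmcv.torchpack.hooks import CheckpointSaverHook, OptimizerStepperHook
# After this commit:
from mmcv.torchpack.hooks import CheckpointHook, OptimizerHook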
mmcv/torchpack/hooks/checkpoint_saver.py

@@ -2,7 +2,7 @@ from .hook import Hook
 from ..utils import master_only
 
 
-class CheckpointSaverHook(Hook):
+class CheckpointHook(Hook):
 
     def __init__(self,
                  interval=-1,
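Under its new name the hook is constructed the same way. Only the interval=-1 default is visible in this hunk, so the sketch below assumes interval counts epochs and that a Runner instance named runner already exists:

from mmcv.torchpack.hooks import CheckpointHook

# Presumably saves a checkpoint every `interval` epochs; the default
# of -1 shown in the diff would then disable interval-based saving.
checkpoint_hook = CheckpointHook(interval=5)
runner.register_hook(checkpoint_hook)  # `runner` is an assumed Runner instance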
mmcv/torchpack/hooks/optimizer_stepper.py

@@ -3,7 +3,7 @@ from torch.nn.utils import clip_grad
 from .hook import Hook
 
 
-class OptimizerStepperHook(Hook):
+class OptimizerHook(Hook):
 
     def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
         self.grad_clip = grad_clip
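The constructor is unchanged by the rename. Given the clip_grad import in the hunk header, grad_clip=True presumably enables gradient clipping with the max_norm/norm_type values shown; a minimal sketch:

from mmcv.torchpack.hooks import OptimizerHook

# Step the optimizer each iteration, clipping gradients to an
# L2 norm of 35 first (the defaults from the diff above).
optimizer_hook = OptimizerHook(grad_clip=True, max_norm=35, norm_type=2)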
mmcv/torchpack/runner/runner.py

@@ -8,8 +8,8 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
 from .log_buffer import LogBuffer
 from .. import hooks
-from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook,
-                     IterTimerHook, OptimizerStepperHook)
+from ..hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
+                     OptimizerHook)
 from ..io import load_checkpoint, save_checkpoint
 from ..utils import (get_dist_info, get_host_info, get_time_str,
                      add_file_handler, obj_from_dict)
@@ -182,6 +182,16 @@ class Runner(object):
         if not inserted:
             self._hooks.insert(0, hook)
 
+    def build_hook(self, args, hook_type=None):
+        if isinstance(args, Hook):
+            return args
+        elif isinstance(args, dict):
+            assert issubclass(hook_type, Hook)
+            return hook_type(**args)
+        else:
+            raise TypeError('"args" must be either a Hook object'
+                            ' or dict, not {}'.format(type(args)))
+
     def call_hook(self, fn_name):
         for hook in self._hooks:
             getattr(hook, fn_name)(self)
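The new build_hook helper accepts either a ready-made Hook instance (returned as-is) or a plain config dict (instantiated as hook_type(**args)). A sketch of both call styles, assuming an existing Runner instance named runner:

# A Hook instance passes through unchanged.
timer_hook = runner.build_hook(IterTimerHook())

# A dict is expanded into the given hook class's constructor,
# i.e. this is equivalent to CheckpointHook(interval=1).
ckpt_hook = runner.build_hook(dict(interval=1), hook_type=CheckpointHook)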
@@ -201,7 +211,7 @@ class Runner(object):
         else:
             meta.update(epoch=self.epoch + 1, iter=self.iter)
 
-        filename = osp.join(out_dir, filename_tmpl.format(self.epoch))
+        filename = osp.join(out_dir, filename_tmpl.format(self.epoch + 1))
         linkname = osp.join(out_dir, 'latest.pth')
         optimizer = self.optimizer if save_optimizer else None
         save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
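This is the checkpoint bug named in the commit message: meta already records epoch=self.epoch + 1, but the filename was formatted with self.epoch, so file names lagged the metadata by one. Assuming a template such as 'epoch_{}.pth' (the actual default is not shown in this diff), after the first epoch (self.epoch == 0):

filename_tmpl = 'epoch_{}.pth'  # hypothetical template for illustration

filename_tmpl.format(0)      # before: 'epoch_0.pth', while meta says epoch=1
filename_tmpl.format(0 + 1)  # after:  'epoch_1.pth', matching the meta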
@@ -213,7 +223,6 @@ class Runner(object):
         self.data_loader = data_loader
         self._max_iters = self._max_epochs * len(data_loader)
         self.call_hook('before_train_epoch')
         for i, data_batch in enumerate(data_loader):
             self._inner_iter = i
             self.call_hook('before_train_iter')
@@ -330,7 +339,7 @@ class Runner(object):
     def register_training_hooks(self,
                                 lr_config,
-                                grad_clip_config=None,
+                                optimizer_config=None,
                                 checkpoint_config=None,
                                 log_config=None):
         """Register default hooks for training.
 
@@ -342,13 +351,13 @@ class Runner(object):
         - IterTimerHook
         - LoggerHook
         """
-        if grad_clip_config is None:
-            grad_clip_config = {}
+        if optimizer_config is None:
+            optimizer_config = {}
         if checkpoint_config is None:
             checkpoint_config = {}
         self.register_lr_hooks(lr_config)
-        self.register_hook(OptimizerStepperHook(**grad_clip_config))
-        self.register_hook(CheckpointSaverHook(**checkpoint_config))
+        self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
+        self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
         self.register_hook(IterTimerHook())
         if log_config is not None:
             self.register_logger_hooks(log_config)
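With the parameter renamed from grad_clip_config to optimizer_config and both hooks routed through build_hook, callers can now pass either config dicts or ready-made hook instances. A sketch of the updated call; the lr_config contents are assumed, since register_lr_hooks is not shown in this diff:

runner.register_training_hooks(
    lr_config=dict(policy='step', step=[8, 11]),  # assumed format
    optimizer_config=dict(grad_clip=True, max_norm=35, norm_type=2),
    checkpoint_config=dict(interval=1),
)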