Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
27c81690
Unverified
Commit
27c81690
authored
Sep 21, 2018
by
Kai Chen
Committed by
GitHub
Sep 21, 2018
Browse files
Merge pull request #7 from OceanPang/env
fix checkpoint & runner bugs
parents
923091b5
818c40c3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
13 deletions
+22
-13
mmcv/torchpack/hooks/__init__.py
mmcv/torchpack/hooks/__init__.py
+2
-2
mmcv/torchpack/hooks/checkpoint_saver.py
mmcv/torchpack/hooks/checkpoint_saver.py
+1
-1
mmcv/torchpack/hooks/optimizer_stepper.py
mmcv/torchpack/hooks/optimizer_stepper.py
+1
-1
mmcv/torchpack/runner/runner.py
mmcv/torchpack/runner/runner.py
+18
-9
No files found.
mmcv/torchpack/hooks/__init__.py
View file @
27c81690
from
.hook
import
Hook
from
.hook
import
Hook
from
.checkpoint_saver
import
Checkpoint
Saver
Hook
from
.checkpoint_saver
import
CheckpointHook
from
.closure
import
ClosureHook
from
.closure
import
ClosureHook
from
.lr_updater
import
LrUpdaterHook
from
.lr_updater
import
LrUpdaterHook
from
.optimizer_stepper
import
Optimizer
Stepper
Hook
from
.optimizer_stepper
import
OptimizerHook
from
.iter_timer
import
IterTimerHook
from
.iter_timer
import
IterTimerHook
from
.logger
import
*
from
.logger
import
*
mmcv/torchpack/hooks/checkpoint_saver.py
View file @
27c81690
...
@@ -2,7 +2,7 @@ from .hook import Hook
...
@@ -2,7 +2,7 @@ from .hook import Hook
from
..utils
import
master_only
from
..utils
import
master_only
class
Checkpoint
Saver
Hook
(
Hook
):
class
CheckpointHook
(
Hook
):
def
__init__
(
self
,
def
__init__
(
self
,
interval
=-
1
,
interval
=-
1
,
...
...
mmcv/torchpack/hooks/optimizer_stepper.py
View file @
27c81690
...
@@ -3,7 +3,7 @@ from torch.nn.utils import clip_grad
...
@@ -3,7 +3,7 @@ from torch.nn.utils import clip_grad
from
.hook
import
Hook
from
.hook
import
Hook
class
Optimizer
Stepper
Hook
(
Hook
):
class
OptimizerHook
(
Hook
):
def
__init__
(
self
,
grad_clip
=
False
,
max_norm
=
35
,
norm_type
=
2
):
def
__init__
(
self
,
grad_clip
=
False
,
max_norm
=
35
,
norm_type
=
2
):
self
.
grad_clip
=
grad_clip
self
.
grad_clip
=
grad_clip
...
...
mmcv/torchpack/runner/runner.py
View file @
27c81690
...
@@ -8,8 +8,8 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
...
@@ -8,8 +8,8 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
from
.log_buffer
import
LogBuffer
from
.log_buffer
import
LogBuffer
from
..
import
hooks
from
..
import
hooks
from
..hooks
import
(
Hook
,
LrUpdaterHook
,
Checkpoint
Saver
Hook
,
IterTimerHook
,
from
..hooks
import
(
Hook
,
LrUpdaterHook
,
CheckpointHook
,
IterTimerHook
,
Optimizer
Stepper
Hook
)
OptimizerHook
)
from
..io
import
load_checkpoint
,
save_checkpoint
from
..io
import
load_checkpoint
,
save_checkpoint
from
..utils
import
(
get_dist_info
,
get_host_info
,
get_time_str
,
from
..utils
import
(
get_dist_info
,
get_host_info
,
get_time_str
,
add_file_handler
,
obj_from_dict
)
add_file_handler
,
obj_from_dict
)
...
@@ -182,6 +182,16 @@ class Runner(object):
...
@@ -182,6 +182,16 @@ class Runner(object):
if
not
inserted
:
if
not
inserted
:
self
.
_hooks
.
insert
(
0
,
hook
)
self
.
_hooks
.
insert
(
0
,
hook
)
def
build_hook
(
self
,
args
,
hook_type
=
None
):
if
isinstance
(
args
,
Hook
):
return
args
elif
isinstance
(
args
,
dict
):
assert
issubclass
(
hook_type
,
Hook
)
return
hook_type
(
**
args
)
else
:
raise
TypeError
(
'"args" must be either a Hook object'
' or dict, not {}'
.
format
(
type
(
args
)))
def
call_hook
(
self
,
fn_name
):
def
call_hook
(
self
,
fn_name
):
for
hook
in
self
.
_hooks
:
for
hook
in
self
.
_hooks
:
getattr
(
hook
,
fn_name
)(
self
)
getattr
(
hook
,
fn_name
)(
self
)
...
@@ -201,7 +211,7 @@ class Runner(object):
...
@@ -201,7 +211,7 @@ class Runner(object):
else
:
else
:
meta
.
update
(
epoch
=
self
.
epoch
+
1
,
iter
=
self
.
iter
)
meta
.
update
(
epoch
=
self
.
epoch
+
1
,
iter
=
self
.
iter
)
filename
=
osp
.
join
(
out_dir
,
filename_tmpl
.
format
(
self
.
epoch
))
filename
=
osp
.
join
(
out_dir
,
filename_tmpl
.
format
(
self
.
epoch
+
1
))
linkname
=
osp
.
join
(
out_dir
,
'latest.pth'
)
linkname
=
osp
.
join
(
out_dir
,
'latest.pth'
)
optimizer
=
self
.
optimizer
if
save_optimizer
else
None
optimizer
=
self
.
optimizer
if
save_optimizer
else
None
save_checkpoint
(
self
.
model
,
filename
,
optimizer
=
optimizer
,
meta
=
meta
)
save_checkpoint
(
self
.
model
,
filename
,
optimizer
=
optimizer
,
meta
=
meta
)
...
@@ -213,7 +223,6 @@ class Runner(object):
...
@@ -213,7 +223,6 @@ class Runner(object):
self
.
data_loader
=
data_loader
self
.
data_loader
=
data_loader
self
.
_max_iters
=
self
.
_max_epochs
*
len
(
data_loader
)
self
.
_max_iters
=
self
.
_max_epochs
*
len
(
data_loader
)
self
.
call_hook
(
'before_train_epoch'
)
self
.
call_hook
(
'before_train_epoch'
)
for
i
,
data_batch
in
enumerate
(
data_loader
):
for
i
,
data_batch
in
enumerate
(
data_loader
):
self
.
_inner_iter
=
i
self
.
_inner_iter
=
i
self
.
call_hook
(
'before_train_iter'
)
self
.
call_hook
(
'before_train_iter'
)
...
@@ -330,7 +339,7 @@ class Runner(object):
...
@@ -330,7 +339,7 @@ class Runner(object):
def
register_training_hooks
(
self
,
def
register_training_hooks
(
self
,
lr_config
,
lr_config
,
grad_clip
_config
=
None
,
optimizer
_config
=
None
,
checkpoint_config
=
None
,
checkpoint_config
=
None
,
log_config
=
None
):
log_config
=
None
):
"""Register default hooks for training.
"""Register default hooks for training.
...
@@ -342,13 +351,13 @@ class Runner(object):
...
@@ -342,13 +351,13 @@ class Runner(object):
- IterTimerHook
- IterTimerHook
- LoggerHook
- LoggerHook
"""
"""
if
grad_clip
_config
is
None
:
if
optimizer
_config
is
None
:
grad_clip
_config
=
{}
optimizer
_config
=
{}
if
checkpoint_config
is
None
:
if
checkpoint_config
is
None
:
checkpoint_config
=
{}
checkpoint_config
=
{}
self
.
register_lr_hooks
(
lr_config
)
self
.
register_lr_hooks
(
lr_config
)
self
.
register_hook
(
OptimizerStepperHook
(
**
grad_clip_config
))
self
.
register_hook
(
self
.
build_hook
(
optimizer_config
,
OptimizerHook
))
self
.
register_hook
(
CheckpointSaverH
ook
(
**
checkpoint_config
))
self
.
register_hook
(
self
.
build_h
ook
(
checkpoint_config
,
CheckpointHook
))
self
.
register_hook
(
IterTimerHook
())
self
.
register_hook
(
IterTimerHook
())
if
log_config
is
not
None
:
if
log_config
is
not
None
:
self
.
register_logger_hooks
(
log_config
)
self
.
register_logger_hooks
(
log_config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment