Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
0d5332a4
Unverified
Commit
0d5332a4
authored
Feb 29, 2020
by
Kai Chen
Committed by
GitHub
Feb 29, 2020
Browse files
use registry to manage hooks (#199)
parent
c2c9fced
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
70 additions
and
49 deletions
+70
-49
mmcv/runner/hooks/__init__.py
mmcv/runner/hooks/__init__.py
+4
-4
mmcv/runner/hooks/checkpoint.py
mmcv/runner/hooks/checkpoint.py
+2
-1
mmcv/runner/hooks/closure.py
mmcv/runner/hooks/closure.py
+2
-1
mmcv/runner/hooks/hook.py
mmcv/runner/hooks/hook.py
+5
-0
mmcv/runner/hooks/iter_timer.py
mmcv/runner/hooks/iter_timer.py
+2
-1
mmcv/runner/hooks/logger/tensorboard.py
mmcv/runner/hooks/logger/tensorboard.py
+3
-1
mmcv/runner/hooks/logger/text.py
mmcv/runner/hooks/logger/text.py
+2
-0
mmcv/runner/hooks/logger/wandb.py
mmcv/runner/hooks/logger/wandb.py
+3
-1
mmcv/runner/hooks/lr_updater.py
mmcv/runner/hooks/lr_updater.py
+7
-1
mmcv/runner/hooks/memory.py
mmcv/runner/hooks/memory.py
+2
-1
mmcv/runner/hooks/optimizer.py
mmcv/runner/hooks/optimizer.py
+2
-1
mmcv/runner/hooks/sampler_seed.py
mmcv/runner/hooks/sampler_seed.py
+2
-1
mmcv/runner/runner.py
mmcv/runner/runner.py
+34
-36
No files found.
mmcv/runner/hooks/__init__.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
.checkpoint
import
CheckpointHook
from
.checkpoint
import
CheckpointHook
from
.closure
import
ClosureHook
from
.closure
import
ClosureHook
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
from
.iter_timer
import
IterTimerHook
from
.iter_timer
import
IterTimerHook
from
.logger
import
(
LoggerHook
,
TensorboardLoggerHook
,
TextLoggerHook
,
from
.logger
import
(
LoggerHook
,
TensorboardLoggerHook
,
TextLoggerHook
,
WandbLoggerHook
)
WandbLoggerHook
)
...
@@ -11,7 +11,7 @@ from .optimizer import OptimizerHook
...
@@ -11,7 +11,7 @@ from .optimizer import OptimizerHook
from
.sampler_seed
import
DistSamplerSeedHook
from
.sampler_seed
import
DistSamplerSeedHook
__all__
=
[
__all__
=
[
'Hook'
,
'CheckpointHook'
,
'ClosureHook'
,
'LrUpdaterHook'
,
'OptimizerHook'
,
'HOOKS'
,
'Hook'
,
'CheckpointHook'
,
'ClosureHook'
,
'LrUpdaterHook'
,
'IterTimerHook'
,
'DistSamplerSeedHook'
,
'EmptyCacheHook'
,
'LoggerHook'
,
'OptimizerHook'
,
'IterTimerHook'
,
'DistSamplerSeedHook'
,
'EmptyCacheHook'
,
'TextLoggerHook'
,
'TensorboardLoggerHook'
,
'WandbLoggerHook'
'LoggerHook'
,
'TextLoggerHook'
,
'TensorboardLoggerHook'
,
'WandbLoggerHook'
]
]
mmcv/runner/hooks/checkpoint.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
..dist_utils
import
master_only
from
..dist_utils
import
master_only
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
CheckpointHook
(
Hook
):
class
CheckpointHook
(
Hook
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmcv/runner/hooks/closure.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
ClosureHook
(
Hook
):
class
ClosureHook
(
Hook
):
def
__init__
(
self
,
fn_name
,
fn
):
def
__init__
(
self
,
fn_name
,
fn
):
...
...
mmcv/runner/hooks/hook.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
mmcv.utils
import
Registry
HOOKS
=
Registry
(
'hook'
)
class
Hook
(
object
):
class
Hook
(
object
):
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
...
...
mmcv/runner/hooks/iter_timer.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
import
time
import
time
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
IterTimerHook
(
Hook
):
class
IterTimerHook
(
Hook
):
def
before_epoch
(
self
,
runner
):
def
before_epoch
(
self
,
runner
):
...
...
mmcv/runner/hooks/logger/tensorboard.py
View file @
0d5332a4
...
@@ -3,10 +3,12 @@ import os.path as osp
...
@@ -3,10 +3,12 @@ import os.path as osp
import
torch
import
torch
from
...dist_utils
import
master_only
from
mmcv.runner
import
master_only
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
@
HOOKS
.
register_module
class
TensorboardLoggerHook
(
LoggerHook
):
class
TensorboardLoggerHook
(
LoggerHook
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmcv/runner/hooks/logger/text.py
View file @
0d5332a4
...
@@ -7,9 +7,11 @@ import torch
...
@@ -7,9 +7,11 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
mmcv
import
mmcv
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
@
HOOKS
.
register_module
class
TextLoggerHook
(
LoggerHook
):
class
TextLoggerHook
(
LoggerHook
):
def
__init__
(
self
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
):
def
__init__
(
self
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
):
...
...
mmcv/runner/hooks/logger/wandb.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
import
numbers
import
numbers
from
...dist_utils
import
master_only
from
mmcv.runner
import
master_only
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
@
HOOKS
.
register_module
class
WandbLoggerHook
(
LoggerHook
):
class
WandbLoggerHook
(
LoggerHook
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmcv/runner/hooks/lr_updater.py
View file @
0d5332a4
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
__future__
import
division
from
__future__
import
division
from
math
import
cos
,
pi
from
math
import
cos
,
pi
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
class
LrUpdaterHook
(
Hook
):
class
LrUpdaterHook
(
Hook
):
...
@@ -88,6 +88,7 @@ class LrUpdaterHook(Hook):
...
@@ -88,6 +88,7 @@ class LrUpdaterHook(Hook):
self
.
_set_lr
(
runner
,
warmup_lr
)
self
.
_set_lr
(
runner
,
warmup_lr
)
@
HOOKS
.
register_module
class
FixedLrUpdaterHook
(
LrUpdaterHook
):
class
FixedLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
...
@@ -97,6 +98,7 @@ class FixedLrUpdaterHook(LrUpdaterHook):
...
@@ -97,6 +98,7 @@ class FixedLrUpdaterHook(LrUpdaterHook):
return
base_lr
return
base_lr
@
HOOKS
.
register_module
class
StepLrUpdaterHook
(
LrUpdaterHook
):
class
StepLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
step
,
gamma
=
0.1
,
**
kwargs
):
def
__init__
(
self
,
step
,
gamma
=
0.1
,
**
kwargs
):
...
@@ -126,6 +128,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
...
@@ -126,6 +128,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
return
base_lr
*
self
.
gamma
**
exp
return
base_lr
*
self
.
gamma
**
exp
@
HOOKS
.
register_module
class
ExpLrUpdaterHook
(
LrUpdaterHook
):
class
ExpLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
gamma
,
**
kwargs
):
def
__init__
(
self
,
gamma
,
**
kwargs
):
...
@@ -137,6 +140,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
...
@@ -137,6 +140,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
return
base_lr
*
self
.
gamma
**
progress
return
base_lr
*
self
.
gamma
**
progress
@
HOOKS
.
register_module
class
PolyLrUpdaterHook
(
LrUpdaterHook
):
class
PolyLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
power
=
1.
,
min_lr
=
0.
,
**
kwargs
):
def
__init__
(
self
,
power
=
1.
,
min_lr
=
0.
,
**
kwargs
):
...
@@ -155,6 +159,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
...
@@ -155,6 +159,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
return
(
base_lr
-
self
.
min_lr
)
*
coeff
+
self
.
min_lr
return
(
base_lr
-
self
.
min_lr
)
*
coeff
+
self
.
min_lr
@
HOOKS
.
register_module
class
InvLrUpdaterHook
(
LrUpdaterHook
):
class
InvLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
gamma
,
power
=
1.
,
**
kwargs
):
def
__init__
(
self
,
gamma
,
power
=
1.
,
**
kwargs
):
...
@@ -167,6 +172,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
...
@@ -167,6 +172,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
return
base_lr
*
(
1
+
self
.
gamma
*
progress
)
**
(
-
self
.
power
)
return
base_lr
*
(
1
+
self
.
gamma
*
progress
)
**
(
-
self
.
power
)
@
HOOKS
.
register_module
class
CosineLrUpdaterHook
(
LrUpdaterHook
):
class
CosineLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
target_lr
=
0
,
**
kwargs
):
def
__init__
(
self
,
target_lr
=
0
,
**
kwargs
):
...
...
mmcv/runner/hooks/memory.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
import
torch
import
torch
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
EmptyCacheHook
(
Hook
):
class
EmptyCacheHook
(
Hook
):
def
__init__
(
self
,
before_epoch
=
False
,
after_epoch
=
True
,
after_iter
=
False
):
def
__init__
(
self
,
before_epoch
=
False
,
after_epoch
=
True
,
after_iter
=
False
):
...
...
mmcv/runner/hooks/optimizer.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
torch.nn.utils
import
clip_grad
from
torch.nn.utils
import
clip_grad
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
OptimizerHook
(
Hook
):
class
OptimizerHook
(
Hook
):
def
__init__
(
self
,
grad_clip
=
None
):
def
__init__
(
self
,
grad_clip
=
None
):
...
...
mmcv/runner/hooks/sampler_seed.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
DistSamplerSeedHook
(
Hook
):
class
DistSamplerSeedHook
(
Hook
):
def
before_epoch
(
self
,
runner
):
def
before_epoch
(
self
,
runner
):
...
...
mmcv/runner/runner.py
View file @
0d5332a4
...
@@ -6,11 +6,9 @@ import time
...
@@ -6,11 +6,9 @@ import time
import
torch
import
torch
import
mmcv
import
mmcv
from
.
import
hooks
from
.checkpoint
import
load_checkpoint
,
save_checkpoint
from
.checkpoint
import
load_checkpoint
,
save_checkpoint
from
.dist_utils
import
get_dist_info
from
.dist_utils
import
get_dist_info
from
.hooks
import
(
CheckpointHook
,
Hook
,
IterTimerHook
,
LrUpdaterHook
,
from
.hooks
import
HOOKS
,
Hook
,
IterTimerHook
OptimizerHook
,
lr_updater
)
from
.log_buffer
import
LogBuffer
from
.log_buffer
import
LogBuffer
from
.priority
import
get_priority
from
.priority
import
get_priority
from
.utils
import
get_host_info
,
get_time_str
,
obj_from_dict
from
.utils
import
get_host_info
,
get_time_str
,
obj_from_dict
...
@@ -223,16 +221,6 @@ class Runner(object):
...
@@ -223,16 +221,6 @@ class Runner(object):
if
not
inserted
:
if
not
inserted
:
self
.
_hooks
.
insert
(
0
,
hook
)
self
.
_hooks
.
insert
(
0
,
hook
)
def
build_hook
(
self
,
args
,
hook_type
=
None
):
if
isinstance
(
args
,
Hook
):
return
args
elif
isinstance
(
args
,
dict
):
assert
issubclass
(
hook_type
,
Hook
)
return
hook_type
(
**
args
)
else
:
raise
TypeError
(
'"args" must be either a Hook object'
' or dict, not {}'
.
format
(
type
(
args
)))
def
call_hook
(
self
,
fn_name
):
def
call_hook
(
self
,
fn_name
):
for
hook
in
self
.
_hooks
:
for
hook
in
self
.
_hooks
:
getattr
(
hook
,
fn_name
)(
self
)
getattr
(
hook
,
fn_name
)(
self
)
...
@@ -373,26 +361,41 @@ class Runner(object):
...
@@ -373,26 +361,41 @@ class Runner(object):
time
.
sleep
(
1
)
# wait for some hooks like loggers to finish
time
.
sleep
(
1
)
# wait for some hooks like loggers to finish
self
.
call_hook
(
'after_run'
)
self
.
call_hook
(
'after_run'
)
def
register_lr_hooks
(
self
,
lr_config
):
def
register_lr_hook
(
self
,
lr_config
):
if
isinstance
(
lr_config
,
LrUpdaterHook
):
if
isinstance
(
lr_config
,
dict
):
self
.
register_hook
(
lr_config
)
elif
isinstance
(
lr_config
,
dict
):
assert
'policy'
in
lr_config
assert
'policy'
in
lr_config
# from .hooks import lr_updater
hook_type
=
lr_config
.
pop
(
'policy'
).
title
()
+
'LrUpdaterHook'
hook_name
=
lr_config
[
'policy'
].
title
()
+
'LrUpdaterHook'
lr_config
[
'type'
]
=
hook_type
if
not
hasattr
(
lr_updater
,
hook_name
):
hook
=
mmcv
.
build_from_cfg
(
lr_config
,
HOOKS
)
raise
ValueError
(
'"{}" does not exist'
.
format
(
hook_name
))
else
:
hook_cls
=
getattr
(
lr_updater
,
hook_name
)
hook
=
lr_config
self
.
register_hook
(
hook_cls
(
**
lr_config
))
self
.
register_hook
(
hook
)
def
register_optimizer_hook
(
self
,
optimizer_config
):
if
optimizer_config
is
None
:
return
if
isinstance
(
optimizer_config
,
dict
):
optimizer_config
.
setdefault
(
'type'
,
'OptimizerHook'
)
hook
=
mmcv
.
build_from_cfg
(
optimizer_config
,
HOOKS
)
else
:
else
:
raise
TypeError
(
'"lr_config" must be either a LrUpdaterHook object'
hook
=
optimizer_config
' or dict, not {}'
.
format
(
type
(
lr_config
)))
self
.
register_hook
(
hook
)
def
register_checkpoint_hook
(
self
,
checkpoint_config
):
if
checkpoint_config
is
None
:
return
if
isinstance
(
checkpoint_config
,
dict
):
checkpoint_config
.
setdefault
(
'type'
,
'CheckpointHook'
)
hook
=
mmcv
.
build_from_cfg
(
checkpoint_config
,
HOOKS
)
else
:
hook
=
checkpoint_config
self
.
register_hook
(
hook
)
def
register_logger_hooks
(
self
,
log_config
):
def
register_logger_hooks
(
self
,
log_config
):
log_interval
=
log_config
[
'interval'
]
log_interval
=
log_config
[
'interval'
]
for
info
in
log_config
[
'hooks'
]:
for
info
in
log_config
[
'hooks'
]:
logger_hook
=
obj
_from_
dict
(
logger_hook
=
mmcv
.
build
_from_
cfg
(
info
,
hooks
,
default_args
=
dict
(
interval
=
log_interval
))
info
,
HOOKS
,
default_args
=
dict
(
interval
=
log_interval
))
self
.
register_hook
(
logger_hook
,
priority
=
'VERY_LOW'
)
self
.
register_hook
(
logger_hook
,
priority
=
'VERY_LOW'
)
def
register_training_hooks
(
self
,
def
register_training_hooks
(
self
,
...
@@ -410,13 +413,8 @@ class Runner(object):
...
@@ -410,13 +413,8 @@ class Runner(object):
- IterTimerHook
- IterTimerHook
- LoggerHook(s)
- LoggerHook(s)
"""
"""
if
optimizer_config
is
None
:
self
.
register_lr_hook
(
lr_config
)
optimizer_config
=
{}
self
.
register_optimizer_hook
(
optimizer_config
)
if
checkpoint_config
is
None
:
self
.
register_checkpoint_hook
(
checkpoint_config
)
checkpoint_config
=
{}
self
.
register_lr_hooks
(
lr_config
)
self
.
register_hook
(
self
.
build_hook
(
optimizer_config
,
OptimizerHook
))
self
.
register_hook
(
self
.
build_hook
(
checkpoint_config
,
CheckpointHook
))
self
.
register_hook
(
IterTimerHook
())
self
.
register_hook
(
IterTimerHook
())
if
log_config
is
not
None
:
self
.
register_logger_hooks
(
log_config
)
self
.
register_logger_hooks
(
log_config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment