Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
0d5332a4
"tests/compute/test_basics.py" did not exist on "c42eac718fe09aae40b7f8dfd8c3466b08e0f0be"
Unverified
Commit
0d5332a4
authored
Feb 29, 2020
by
Kai Chen
Committed by
GitHub
Feb 29, 2020
Browse files
use registry to manage hooks (#199)
parent
c2c9fced
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
70 additions
and
49 deletions
+70
-49
mmcv/runner/hooks/__init__.py
mmcv/runner/hooks/__init__.py
+4
-4
mmcv/runner/hooks/checkpoint.py
mmcv/runner/hooks/checkpoint.py
+2
-1
mmcv/runner/hooks/closure.py
mmcv/runner/hooks/closure.py
+2
-1
mmcv/runner/hooks/hook.py
mmcv/runner/hooks/hook.py
+5
-0
mmcv/runner/hooks/iter_timer.py
mmcv/runner/hooks/iter_timer.py
+2
-1
mmcv/runner/hooks/logger/tensorboard.py
mmcv/runner/hooks/logger/tensorboard.py
+3
-1
mmcv/runner/hooks/logger/text.py
mmcv/runner/hooks/logger/text.py
+2
-0
mmcv/runner/hooks/logger/wandb.py
mmcv/runner/hooks/logger/wandb.py
+3
-1
mmcv/runner/hooks/lr_updater.py
mmcv/runner/hooks/lr_updater.py
+7
-1
mmcv/runner/hooks/memory.py
mmcv/runner/hooks/memory.py
+2
-1
mmcv/runner/hooks/optimizer.py
mmcv/runner/hooks/optimizer.py
+2
-1
mmcv/runner/hooks/sampler_seed.py
mmcv/runner/hooks/sampler_seed.py
+2
-1
mmcv/runner/runner.py
mmcv/runner/runner.py
+34
-36
No files found.
mmcv/runner/hooks/__init__.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
.checkpoint
import
CheckpointHook
from
.checkpoint
import
CheckpointHook
from
.closure
import
ClosureHook
from
.closure
import
ClosureHook
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
from
.iter_timer
import
IterTimerHook
from
.iter_timer
import
IterTimerHook
from
.logger
import
(
LoggerHook
,
TensorboardLoggerHook
,
TextLoggerHook
,
from
.logger
import
(
LoggerHook
,
TensorboardLoggerHook
,
TextLoggerHook
,
WandbLoggerHook
)
WandbLoggerHook
)
...
@@ -11,7 +11,7 @@ from .optimizer import OptimizerHook
...
@@ -11,7 +11,7 @@ from .optimizer import OptimizerHook
from
.sampler_seed
import
DistSamplerSeedHook
from
.sampler_seed
import
DistSamplerSeedHook
__all__
=
[
__all__
=
[
'Hook'
,
'CheckpointHook'
,
'ClosureHook'
,
'LrUpdaterHook'
,
'OptimizerHook'
,
'HOOKS'
,
'Hook'
,
'CheckpointHook'
,
'ClosureHook'
,
'LrUpdaterHook'
,
'IterTimerHook'
,
'DistSamplerSeedHook'
,
'EmptyCacheHook'
,
'LoggerHook'
,
'OptimizerHook'
,
'IterTimerHook'
,
'DistSamplerSeedHook'
,
'EmptyCacheHook'
,
'TextLoggerHook'
,
'TensorboardLoggerHook'
,
'WandbLoggerHook'
'LoggerHook'
,
'TextLoggerHook'
,
'TensorboardLoggerHook'
,
'WandbLoggerHook'
]
]
mmcv/runner/hooks/checkpoint.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
..dist_utils
import
master_only
from
..dist_utils
import
master_only
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
CheckpointHook
(
Hook
):
class
CheckpointHook
(
Hook
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmcv/runner/hooks/closure.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
ClosureHook
(
Hook
):
class
ClosureHook
(
Hook
):
def
__init__
(
self
,
fn_name
,
fn
):
def
__init__
(
self
,
fn_name
,
fn
):
...
...
mmcv/runner/hooks/hook.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
mmcv.utils
import
Registry
HOOKS
=
Registry
(
'hook'
)
class
Hook
(
object
):
class
Hook
(
object
):
def
before_run
(
self
,
runner
):
def
before_run
(
self
,
runner
):
...
...
mmcv/runner/hooks/iter_timer.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
import
time
import
time
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
IterTimerHook
(
Hook
):
class
IterTimerHook
(
Hook
):
def
before_epoch
(
self
,
runner
):
def
before_epoch
(
self
,
runner
):
...
...
mmcv/runner/hooks/logger/tensorboard.py
View file @
0d5332a4
...
@@ -3,10 +3,12 @@ import os.path as osp
...
@@ -3,10 +3,12 @@ import os.path as osp
import
torch
import
torch
from
...dist_utils
import
master_only
from
mmcv.runner
import
master_only
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
@
HOOKS
.
register_module
class
TensorboardLoggerHook
(
LoggerHook
):
class
TensorboardLoggerHook
(
LoggerHook
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmcv/runner/hooks/logger/text.py
View file @
0d5332a4
...
@@ -7,9 +7,11 @@ import torch
...
@@ -7,9 +7,11 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
mmcv
import
mmcv
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
@
HOOKS
.
register_module
class
TextLoggerHook
(
LoggerHook
):
class
TextLoggerHook
(
LoggerHook
):
def
__init__
(
self
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
):
def
__init__
(
self
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
False
):
...
...
mmcv/runner/hooks/logger/wandb.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
import
numbers
import
numbers
from
...dist_utils
import
master_only
from
mmcv.runner
import
master_only
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
@
HOOKS
.
register_module
class
WandbLoggerHook
(
LoggerHook
):
class
WandbLoggerHook
(
LoggerHook
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmcv/runner/hooks/lr_updater.py
View file @
0d5332a4
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
__future__
import
division
from
__future__
import
division
from
math
import
cos
,
pi
from
math
import
cos
,
pi
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
class
LrUpdaterHook
(
Hook
):
class
LrUpdaterHook
(
Hook
):
...
@@ -88,6 +88,7 @@ class LrUpdaterHook(Hook):
...
@@ -88,6 +88,7 @@ class LrUpdaterHook(Hook):
self
.
_set_lr
(
runner
,
warmup_lr
)
self
.
_set_lr
(
runner
,
warmup_lr
)
@
HOOKS
.
register_module
class
FixedLrUpdaterHook
(
LrUpdaterHook
):
class
FixedLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
...
@@ -97,6 +98,7 @@ class FixedLrUpdaterHook(LrUpdaterHook):
...
@@ -97,6 +98,7 @@ class FixedLrUpdaterHook(LrUpdaterHook):
return
base_lr
return
base_lr
@
HOOKS
.
register_module
class
StepLrUpdaterHook
(
LrUpdaterHook
):
class
StepLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
step
,
gamma
=
0.1
,
**
kwargs
):
def
__init__
(
self
,
step
,
gamma
=
0.1
,
**
kwargs
):
...
@@ -126,6 +128,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
...
@@ -126,6 +128,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
return
base_lr
*
self
.
gamma
**
exp
return
base_lr
*
self
.
gamma
**
exp
@
HOOKS
.
register_module
class
ExpLrUpdaterHook
(
LrUpdaterHook
):
class
ExpLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
gamma
,
**
kwargs
):
def
__init__
(
self
,
gamma
,
**
kwargs
):
...
@@ -137,6 +140,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
...
@@ -137,6 +140,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
return
base_lr
*
self
.
gamma
**
progress
return
base_lr
*
self
.
gamma
**
progress
@
HOOKS
.
register_module
class
PolyLrUpdaterHook
(
LrUpdaterHook
):
class
PolyLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
power
=
1.
,
min_lr
=
0.
,
**
kwargs
):
def
__init__
(
self
,
power
=
1.
,
min_lr
=
0.
,
**
kwargs
):
...
@@ -155,6 +159,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
...
@@ -155,6 +159,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
return
(
base_lr
-
self
.
min_lr
)
*
coeff
+
self
.
min_lr
return
(
base_lr
-
self
.
min_lr
)
*
coeff
+
self
.
min_lr
@
HOOKS
.
register_module
class
InvLrUpdaterHook
(
LrUpdaterHook
):
class
InvLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
gamma
,
power
=
1.
,
**
kwargs
):
def
__init__
(
self
,
gamma
,
power
=
1.
,
**
kwargs
):
...
@@ -167,6 +172,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
...
@@ -167,6 +172,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
return
base_lr
*
(
1
+
self
.
gamma
*
progress
)
**
(
-
self
.
power
)
return
base_lr
*
(
1
+
self
.
gamma
*
progress
)
**
(
-
self
.
power
)
@
HOOKS
.
register_module
class
CosineLrUpdaterHook
(
LrUpdaterHook
):
class
CosineLrUpdaterHook
(
LrUpdaterHook
):
def
__init__
(
self
,
target_lr
=
0
,
**
kwargs
):
def
__init__
(
self
,
target_lr
=
0
,
**
kwargs
):
...
...
mmcv/runner/hooks/memory.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
import
torch
import
torch
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
EmptyCacheHook
(
Hook
):
class
EmptyCacheHook
(
Hook
):
def
__init__
(
self
,
before_epoch
=
False
,
after_epoch
=
True
,
after_iter
=
False
):
def
__init__
(
self
,
before_epoch
=
False
,
after_epoch
=
True
,
after_iter
=
False
):
...
...
mmcv/runner/hooks/optimizer.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
torch.nn.utils
import
clip_grad
from
torch.nn.utils
import
clip_grad
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
OptimizerHook
(
Hook
):
class
OptimizerHook
(
Hook
):
def
__init__
(
self
,
grad_clip
=
None
):
def
__init__
(
self
,
grad_clip
=
None
):
...
...
mmcv/runner/hooks/sampler_seed.py
View file @
0d5332a4
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
.hook
import
Hook
from
.hook
import
HOOKS
,
Hook
@
HOOKS
.
register_module
class
DistSamplerSeedHook
(
Hook
):
class
DistSamplerSeedHook
(
Hook
):
def
before_epoch
(
self
,
runner
):
def
before_epoch
(
self
,
runner
):
...
...
mmcv/runner/runner.py
View file @
0d5332a4
...
@@ -6,11 +6,9 @@ import time
...
@@ -6,11 +6,9 @@ import time
import
torch
import
torch
import
mmcv
import
mmcv
from
.
import
hooks
from
.checkpoint
import
load_checkpoint
,
save_checkpoint
from
.checkpoint
import
load_checkpoint
,
save_checkpoint
from
.dist_utils
import
get_dist_info
from
.dist_utils
import
get_dist_info
from
.hooks
import
(
CheckpointHook
,
Hook
,
IterTimerHook
,
LrUpdaterHook
,
from
.hooks
import
HOOKS
,
Hook
,
IterTimerHook
OptimizerHook
,
lr_updater
)
from
.log_buffer
import
LogBuffer
from
.log_buffer
import
LogBuffer
from
.priority
import
get_priority
from
.priority
import
get_priority
from
.utils
import
get_host_info
,
get_time_str
,
obj_from_dict
from
.utils
import
get_host_info
,
get_time_str
,
obj_from_dict
...
@@ -223,16 +221,6 @@ class Runner(object):
...
@@ -223,16 +221,6 @@ class Runner(object):
if
not
inserted
:
if
not
inserted
:
self
.
_hooks
.
insert
(
0
,
hook
)
self
.
_hooks
.
insert
(
0
,
hook
)
def
build_hook
(
self
,
args
,
hook_type
=
None
):
if
isinstance
(
args
,
Hook
):
return
args
elif
isinstance
(
args
,
dict
):
assert
issubclass
(
hook_type
,
Hook
)
return
hook_type
(
**
args
)
else
:
raise
TypeError
(
'"args" must be either a Hook object'
' or dict, not {}'
.
format
(
type
(
args
)))
def
call_hook
(
self
,
fn_name
):
def
call_hook
(
self
,
fn_name
):
for
hook
in
self
.
_hooks
:
for
hook
in
self
.
_hooks
:
getattr
(
hook
,
fn_name
)(
self
)
getattr
(
hook
,
fn_name
)(
self
)
...
@@ -373,26 +361,41 @@ class Runner(object):
...
@@ -373,26 +361,41 @@ class Runner(object):
time
.
sleep
(
1
)
# wait for some hooks like loggers to finish
time
.
sleep
(
1
)
# wait for some hooks like loggers to finish
self
.
call_hook
(
'after_run'
)
self
.
call_hook
(
'after_run'
)
def
register_lr_hooks
(
self
,
lr_config
):
def
register_lr_hook
(
self
,
lr_config
):
if
isinstance
(
lr_config
,
LrUpdaterHook
):
if
isinstance
(
lr_config
,
dict
):
self
.
register_hook
(
lr_config
)
elif
isinstance
(
lr_config
,
dict
):
assert
'policy'
in
lr_config
assert
'policy'
in
lr_config
# from .hooks import lr_updater
hook_type
=
lr_config
.
pop
(
'policy'
).
title
()
+
'LrUpdaterHook'
hook_name
=
lr_config
[
'policy'
].
title
()
+
'LrUpdaterHook'
lr_config
[
'type'
]
=
hook_type
if
not
hasattr
(
lr_updater
,
hook_name
):
hook
=
mmcv
.
build_from_cfg
(
lr_config
,
HOOKS
)
raise
ValueError
(
'"{}" does not exist'
.
format
(
hook_name
))
else
:
hook_cls
=
getattr
(
lr_updater
,
hook_name
)
hook
=
lr_config
self
.
register_hook
(
hook_cls
(
**
lr_config
))
self
.
register_hook
(
hook
)
def
register_optimizer_hook
(
self
,
optimizer_config
):
if
optimizer_config
is
None
:
return
if
isinstance
(
optimizer_config
,
dict
):
optimizer_config
.
setdefault
(
'type'
,
'OptimizerHook'
)
hook
=
mmcv
.
build_from_cfg
(
optimizer_config
,
HOOKS
)
else
:
else
:
raise
TypeError
(
'"lr_config" must be either a LrUpdaterHook object'
hook
=
optimizer_config
' or dict, not {}'
.
format
(
type
(
lr_config
)))
self
.
register_hook
(
hook
)
def
register_checkpoint_hook
(
self
,
checkpoint_config
):
if
checkpoint_config
is
None
:
return
if
isinstance
(
checkpoint_config
,
dict
):
checkpoint_config
.
setdefault
(
'type'
,
'CheckpointHook'
)
hook
=
mmcv
.
build_from_cfg
(
checkpoint_config
,
HOOKS
)
else
:
hook
=
checkpoint_config
self
.
register_hook
(
hook
)
def
register_logger_hooks
(
self
,
log_config
):
def
register_logger_hooks
(
self
,
log_config
):
log_interval
=
log_config
[
'interval'
]
log_interval
=
log_config
[
'interval'
]
for
info
in
log_config
[
'hooks'
]:
for
info
in
log_config
[
'hooks'
]:
logger_hook
=
obj
_from_
dict
(
logger_hook
=
mmcv
.
build
_from_
cfg
(
info
,
hooks
,
default_args
=
dict
(
interval
=
log_interval
))
info
,
HOOKS
,
default_args
=
dict
(
interval
=
log_interval
))
self
.
register_hook
(
logger_hook
,
priority
=
'VERY_LOW'
)
self
.
register_hook
(
logger_hook
,
priority
=
'VERY_LOW'
)
def
register_training_hooks
(
self
,
def
register_training_hooks
(
self
,
...
@@ -410,13 +413,8 @@ class Runner(object):
...
@@ -410,13 +413,8 @@ class Runner(object):
- IterTimerHook
- IterTimerHook
- LoggerHook(s)
- LoggerHook(s)
"""
"""
if
optimizer_config
is
None
:
self
.
register_lr_hook
(
lr_config
)
optimizer_config
=
{}
self
.
register_optimizer_hook
(
optimizer_config
)
if
checkpoint_config
is
None
:
self
.
register_checkpoint_hook
(
checkpoint_config
)
checkpoint_config
=
{}
self
.
register_lr_hooks
(
lr_config
)
self
.
register_hook
(
self
.
build_hook
(
optimizer_config
,
OptimizerHook
))
self
.
register_hook
(
self
.
build_hook
(
checkpoint_config
,
CheckpointHook
))
self
.
register_hook
(
IterTimerHook
())
self
.
register_hook
(
IterTimerHook
())
if
log_config
is
not
None
:
self
.
register_logger_hooks
(
log_config
)
self
.
register_logger_hooks
(
log_config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment