MMCV commit 523f861b: add priority enum
Authored Oct 05, 2018 by Kai Chen
Parent: f573c11c
Showing 3 changed files with 70 additions and 10 deletions.

    mmcv/runner/__init__.py    +4   -2
    mmcv/runner/priority.py    +33  -0   (new file)
    mmcv/runner/runner.py      +33  -8
mmcv/runner/__init__.py

@@ -7,6 +7,7 @@ from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
 from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
                          save_checkpoint)
 from .parallel import parallel_test, worker_func
+from .priority import Priority, get_priority
 from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
                     obj_from_dict)
@@ -15,6 +16,7 @@ __all__ = [
     'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
     'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
     'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
-    'parallel_test', 'worker_func', 'get_host_info', 'get_dist_info',
-    'master_only', 'get_time_str', 'obj_from_dict'
+    'parallel_test', 'worker_func', 'Priority', 'get_priority',
+    'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
+    'obj_from_dict'
 ]
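After this change, Priority and get_priority are part of the public mmcv.runner namespace. A minimal usage sketch (assuming mmcv at this revision is importable):

from mmcv.runner import Priority, get_priority

print(Priority.NORMAL)          # Priority.NORMAL
print(get_priority('highest'))  # 0; the name is upper-cased and looked up in the enum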
mmcv/runner/priority.py (new file, 0 → 100644)

+from enum import Enum
+
+
+class Priority(Enum):
+    HIGHEST = 0
+    VERY_HIGH = 20
+    HIGH = 40
+    NORMAL = 50
+    LOW = 60
+    VERY_LOW = 80
+    LOWEST = 100
+
+
+def get_priority(priority):
+    """Get priority value.
+
+    Args:
+        priority (int or str or :obj:`Priority`): Priority.
+
+    Returns:
+        int: The priority value.
+    """
+    if isinstance(priority, int):
+        if priority < 0 or priority > 100:
+            raise ValueError('priority must be between 0 and 100')
+        return priority
+    elif isinstance(priority, Priority):
+        return priority.value
+    elif isinstance(priority, str):
+        return Priority[priority.upper()].value
+    else:
+        raise TypeError('priority must be an integer or Priority enum value')
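For reference, a short sketch of how the new get_priority helper resolves each accepted input type; the expected values follow directly from the enum definition above:

from mmcv.runner.priority import Priority, get_priority

assert get_priority(30) == 30             # plain int, returned after the 0..100 range check
assert get_priority(Priority.HIGH) == 40  # Priority member, mapped to its value
assert get_priority('very_low') == 80     # string, upper-cased and looked up in the enum

# get_priority(120) raises ValueError; get_priority(3.5) raises TypeError.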
mmcv/runner/runner.py

@@ -10,11 +10,24 @@ from .log_buffer import LogBuffer
 from .hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
                     OptimizerHook, lr_updater)
 from .checkpoint import load_checkpoint, save_checkpoint
+from .priority import get_priority
 from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
 
 
 class Runner(object):
-    """A training helper for PyTorch."""
+    """A training helper for PyTorch.
+
+    Args:
+        model (:obj:`torch.nn.Module`): The model to be run.
+        batch_processor (callable): A callable method that process a data
+            batch. The interface of this method should be
+            `batch_processor(model, data, train_mode) -> dict`
+        optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
+            runner will construct an optimizer according to it.
+        work_dir (str, optional): The working directory to save checkpoints
+            and logs.
+        log_level (int): Logging level.
+    """
 
     def __init__(self,
                  model,
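The expanded class docstring spells out the batch_processor contract. A hedged sketch of a callable matching that interface (the (img, label) unpacking and the loss key are illustrative assumptions; the hunk itself only specifies the signature and that a dict is returned):

import torch.nn.functional as F

def batch_processor(model, data, train_mode):
    # Illustrative only: treat `data` as an (img, label) pair and return a
    # dict, matching `batch_processor(model, data, train_mode) -> dict`.
    img, label = data
    loss = F.cross_entropy(model(img), label)
    return dict(loss=loss)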
@@ -154,8 +167,8 @@ class Runner(object):
         logging.basicConfig(
             format='%(asctime)s - %(levelname)s - %(message)s', level=level)
         logger = logging.getLogger(__name__)
-        if log_dir:
-            filename = '{}_{}.log'.format(get_time_str(), self.rank)
+        if log_dir and self.rank == 0:
+            filename = '{}.log'.format(get_time_str())
             log_file = osp.join(log_dir, filename)
             self._add_file_handler(logger, log_file, level=level)
         return logger
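The logging change restricts file output to the rank-0 process and drops the rank suffix from the filename. A hedged sketch of the resulting filename (get_time_str is exported from mmcv.runner; the exact timestamp format is whatever that helper produces):

from mmcv.runner import get_time_str

# Before: each process wrote '<time>_<rank>.log'; after: only rank 0 writes '<time>.log'.
print('{}.log'.format(get_time_str()))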
@@ -171,17 +184,18 @@ class Runner(object):
                 'lr is not applicable because optimizer does not exist.')
         return [group['lr'] for group in self.optimizer.param_groups]
 
-    def register_hook(self, hook, priority=50):
+    def register_hook(self, hook, priority='NORMAL'):
         """Register a hook into the hook list.
 
         Args:
             hook (:obj:`Hook`): The hook to be registered.
-            priority (int): Hook priority. Lower value means higher priority.
+            priority (int or str or :obj:`Priority`): Hook priority.
+                Lower value means higher priority.
         """
         assert isinstance(hook, Hook)
-        assert isinstance(priority, int) and priority >= 0 and priority <= 100
         if hasattr(hook, 'priority'):
             raise ValueError('"priority" is a reserved attribute for hooks')
+        priority = get_priority(priority)
         hook.priority = priority
         # insert the hook to a sorted list
         inserted = False
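With get_priority wired into register_hook, a hook's priority can now be given as an int, a Priority member, or a case-insensitive name, and it is stored as the resolved integer. A small sketch (runner is assumed to be an existing Runner instance; IterTimerHook is used only as a convenient concrete hook):

from mmcv.runner import IterTimerHook

hook = IterTimerHook()
runner.register_hook(hook, priority='LOW')  # same effect as priority=60
assert hook.priority == 60                  # stored as the resolved int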
@@ -292,6 +306,17 @@ class Runner(object):
         self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
 
     def run(self, data_loaders, workflow, max_epochs, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, epochs) to specify the
+                running order and epochs. E.g, [('train', 2), ('val', 1)] means
+                running 2 epochs for training and 1 epoch for validation,
+                iteratively.
+            max_epochs (int): Total training epochs.
+        """
         assert isinstance(data_loaders, list)
         assert mmcv.is_list_of(workflow, tuple)
         assert len(data_loaders) == len(workflow)
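The new run() docstring pins down the workflow format. A hedged usage sketch (runner, train_loader, and val_loader are assumed to exist already; the epoch counts are arbitrary):

# Alternate 2 training epochs with 1 validation epoch until max_epochs
# training epochs have been run; data_loaders and workflow must have the
# same length, as the asserts in the hunk above require.
workflow = [('train', 2), ('val', 1)]
runner.run([train_loader, val_loader], workflow, max_epochs=12)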
@@ -346,7 +371,7 @@ class Runner(object):
         for info in log_config['hooks']:
             logger_hook = obj_from_dict(
                 info, hooks, default_args=dict(interval=log_interval))
-            self.register_hook(logger_hook, priority=60)
+            self.register_hook(logger_hook, priority='VERY_LOW')
 
     def register_training_hooks(self,
                                 lr_config,
@@ -360,7 +385,7 @@ class Runner(object):
         - OptimizerStepperHook
         - CheckpointSaverHook
         - IterTimerHook
-        - LoggerHook
+        - LoggerHook(s)
         """
         if optimizer_config is None:
             optimizer_config = {}