Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
f4550cd3
Unverified
Commit
f4550cd3
authored
Oct 06, 2018
by
Kai Chen
Committed by
GitHub
Oct 06, 2018
Browse files
Merge pull request #14 from open-mmlab/docs
Draft documentation
parents
ad98e856
4cfd45f7
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
8 deletions
+34
-8
mmcv/runner/runner.py
mmcv/runner/runner.py
+34
-8
No files found.
mmcv/runner/runner.py
View file @
f4550cd3
...
...
@@ -10,11 +10,24 @@ from .log_buffer import LogBuffer
from
.hooks
import
(
Hook
,
LrUpdaterHook
,
CheckpointHook
,
IterTimerHook
,
OptimizerHook
,
lr_updater
)
from
.checkpoint
import
load_checkpoint
,
save_checkpoint
from
.priority
import
get_priority
from
.utils
import
get_dist_info
,
get_host_info
,
get_time_str
,
obj_from_dict
class
Runner
(
object
):
"""A training helper for PyTorch."""
"""A training helper for PyTorch.
Args:
model (:obj:`torch.nn.Module`): The model to be run.
batch_processor (callable): A callable method that process a data
batch. The interface of this method should be
`batch_processor(model, data, train_mode) -> dict`
optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
runner will construct an optimizer according to it.
work_dir (str, optional): The working directory to save checkpoints
and logs.
log_level (int): Logging level.
"""
def
__init__
(
self
,
model
,
...
...
@@ -154,8 +167,8 @@ class Runner(object):
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
,
level
=
level
)
logger
=
logging
.
getLogger
(
__name__
)
if
log_dir
:
filename
=
'{}
_{}
.log'
.
format
(
get_time_str
()
,
self
.
rank
)
if
log_dir
and
self
.
rank
==
0
:
filename
=
'{}.log'
.
format
(
get_time_str
())
log_file
=
osp
.
join
(
log_dir
,
filename
)
self
.
_add_file_handler
(
logger
,
log_file
,
level
=
level
)
return
logger
...
...
@@ -171,17 +184,18 @@ class Runner(object):
'lr is not applicable because optimizer does not exist.'
)
return
[
group
[
'lr'
]
for
group
in
self
.
optimizer
.
param_groups
]
def
register_hook
(
self
,
hook
,
priority
=
50
):
def
register_hook
(
self
,
hook
,
priority
=
'NORMAL'
):
"""Register a hook into the hook list.
Args:
hook (:obj:`Hook`): The hook to be registered.
priority (int): Hook priority. Lower value means higher priority.
priority (int or str or :obj:`Priority`): Hook priority.
Lower value means higher priority.
"""
assert
isinstance
(
hook
,
Hook
)
assert
isinstance
(
priority
,
int
)
and
priority
>=
0
and
priority
<=
100
if
hasattr
(
hook
,
'priority'
):
raise
ValueError
(
'"priority" is a reserved attribute for hooks'
)
priority
=
get_priority
(
priority
)
hook
.
priority
=
priority
# insert the hook to a sorted list
inserted
=
False
...
...
@@ -292,6 +306,17 @@ class Runner(object):
self
.
logger
.
info
(
'resumed epoch %d, iter %d'
,
self
.
epoch
,
self
.
iter
)
def
run
(
self
,
data_loaders
,
workflow
,
max_epochs
,
**
kwargs
):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert
isinstance
(
data_loaders
,
list
)
assert
mmcv
.
is_list_of
(
workflow
,
tuple
)
assert
len
(
data_loaders
)
==
len
(
workflow
)
...
...
@@ -346,7 +371,7 @@ class Runner(object):
for
info
in
log_config
[
'hooks'
]:
logger_hook
=
obj_from_dict
(
info
,
hooks
,
default_args
=
dict
(
interval
=
log_interval
))
self
.
register_hook
(
logger_hook
,
priority
=
60
)
self
.
register_hook
(
logger_hook
,
priority
=
'VERY_LOW'
)
def
register_training_hooks
(
self
,
lr_config
,
...
...
@@ -356,11 +381,12 @@ class Runner(object):
"""Register default hooks for training.
Default hooks include:
- LrUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook
- LoggerHook
(s)
"""
if
optimizer_config
is
None
:
optimizer_config
=
{}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment