Unverified commit f4550cd3, authored by Kai Chen, committed by GitHub

Merge pull request #14 from open-mmlab/docs

Draft documentation
Parents: ad98e856, 4cfd45f7
@@ -10,11 +10,24 @@ from .log_buffer import LogBuffer
 from .hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
                     OptimizerHook, lr_updater)
 from .checkpoint import load_checkpoint, save_checkpoint
+from .priority import get_priority
 from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict


 class Runner(object):
-    """A training helper for PyTorch."""
+    """A training helper for PyTorch.
+
+    Args:
+        model (:obj:`torch.nn.Module`): The model to be run.
+        batch_processor (callable): A callable method that processes a data
+            batch. The interface of this method should be
+            `batch_processor(model, data, train_mode) -> dict`.
+        optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
+            runner will construct an optimizer according to it.
+        work_dir (str, optional): The working directory to save checkpoints
+            and logs.
+        log_level (int): Logging level.
+    """

     def __init__(self,
                  model,
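The batch_processor interface documented above is the Runner's main extension point. A minimal sketch of such a callable, assuming a simple classification model; the log_vars key is an illustrative assumption, since the docstring only requires that a dict be returned:

import torch.nn.functional as F

def batch_processor(model, data, train_mode):
    # Unpack one batch and compute the loss; train_mode lets the same
    # callable serve both 'train' and 'val' phases of the workflow.
    imgs, labels = data
    logits = model(imgs)
    loss = F.cross_entropy(logits, labels)
    # Only returning a dict is specified; 'log_vars' is an assumed
    # convention for exposing scalar values to logger hooks.
    return dict(loss=loss, log_vars=dict(loss=loss.item()))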
@@ -154,8 +167,8 @@ class Runner(object):
         logging.basicConfig(
             format='%(asctime)s - %(levelname)s - %(message)s', level=level)
         logger = logging.getLogger(__name__)
-        if log_dir:
-            filename = '{}_{}.log'.format(get_time_str(), self.rank)
+        if log_dir and self.rank == 0:
+            filename = '{}.log'.format(get_time_str())
             log_file = osp.join(log_dir, filename)
             self._add_file_handler(logger, log_file, level=level)
         return logger
@@ -171,17 +184,18 @@ class Runner(object):
                 'lr is not applicable because optimizer does not exist.')
         return [group['lr'] for group in self.optimizer.param_groups]

-    def register_hook(self, hook, priority=50):
+    def register_hook(self, hook, priority='NORMAL'):
         """Register a hook into the hook list.

         Args:
             hook (:obj:`Hook`): The hook to be registered.
-            priority (int): Hook priority. Lower value means higher priority.
+            priority (int or str or :obj:`Priority`): Hook priority.
+                Lower value means higher priority.
         """
         assert isinstance(hook, Hook)
-        assert isinstance(priority, int) and priority >= 0 and priority <= 100
         if hasattr(hook, 'priority'):
             raise ValueError('"priority" is a reserved attribute for hooks')
+        priority = get_priority(priority)
         hook.priority = priority
         # insert the hook to a sorted list
         inserted = False
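With this change, a hook can be registered with an integer, a name such as 'NORMAL', or a Priority value, and get_priority normalizes all three to an integer before insertion. A usage sketch; the hook class and the import path are assumptions for illustration:

from mmcv.runner import Hook  # import path assumed

class EpochAnnouncerHook(Hook):  # hypothetical hook
    def after_epoch(self, runner):
        runner.logger.info('epoch %d finished', runner.epoch)

runner.register_hook(EpochAnnouncerHook())                 # default 'NORMAL'
runner.register_hook(EpochAnnouncerHook(), priority=30)    # ints still work
runner.register_hook(EpochAnnouncerHook(), priority='VERY_LOW')  # runs last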
@@ -292,6 +306,17 @@ class Runner(object):
         self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

     def run(self, data_loaders, workflow, max_epochs, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, epochs) to specify the
+                running order and epochs. E.g., [('train', 2), ('val', 1)]
+                means running 2 epochs for training and 1 epoch for
+                validation, iteratively.
+            max_epochs (int): Total training epochs.
+        """
         assert isinstance(data_loaders, list)
         assert mmcv.is_list_of(workflow, tuple)
         assert len(data_loaders) == len(workflow)
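A usage sketch matching the docstring's example workflow; the two loaders are assumed to exist, and one loader is paired with each workflow phase, as the len(data_loaders) == len(workflow) assertion requires:

# Alternate 2 training epochs with 1 validation epoch until
# 12 training epochs have completed.
runner.run([train_loader, val_loader],
           workflow=[('train', 2), ('val', 1)],
           max_epochs=12)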
@@ -346,7 +371,7 @@ class Runner(object):
         for info in log_config['hooks']:
             logger_hook = obj_from_dict(
                 info, hooks, default_args=dict(interval=log_interval))
-            self.register_hook(logger_hook, priority=60)
+            self.register_hook(logger_hook, priority='VERY_LOW')

     def register_training_hooks(self,
                                 lr_config,
@@ -356,11 +381,12 @@ class Runner(object):
         """Register default hooks for training.

         Default hooks include:

         - LrUpdaterHook
         - OptimizerStepperHook
         - CheckpointSaverHook
         - IterTimerHook
-        - LoggerHook
+        - LoggerHook(s)
         """
         if optimizer_config is None:
             optimizer_config = {}
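A sketch of registering these defaults; the config shapes are inferred from the visible code (the log_config['hooks'] loop and the optimizer_config fallback above), and every key, value, and hook type name here is an assumption for illustration:

runner.register_training_hooks(
    lr_config=dict(policy='step', step=[8, 11]),  # assumed lr_updater config
    optimizer_config=dict(grad_clip=None),        # assumed OptimizerHook args
    checkpoint_config=dict(interval=1),           # assumed CheckpointHook args
    log_config=dict(
        interval=50,                      # default interval for logger hooks
        hooks=[dict(type='TextLoggerHook')]))     # hook class name assumed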