Unverified commit f4550cd3 authored by Kai Chen, committed by GitHub

Merge pull request #14 from open-mmlab/docs

Draft documentation
parents ad98e856 4cfd45f7
@@ -10,11 +10,24 @@ from .log_buffer import LogBuffer
 from .hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
                     OptimizerHook, lr_updater)
 from .checkpoint import load_checkpoint, save_checkpoint
+from .priority import get_priority
 from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
 
 
 class Runner(object):
-    """A training helper for PyTorch."""
+    """A training helper for PyTorch.
+
+    Args:
+        model (:obj:`torch.nn.Module`): The model to be run.
+        batch_processor (callable): A callable method that processes a data
+            batch. The interface of this method should be
+            ``batch_processor(model, data, train_mode) -> dict``.
+        optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
+            runner will construct an optimizer according to it.
+        work_dir (str, optional): The working directory to save checkpoints
+            and logs.
+        log_level (int): Logging level.
+    """
 
     def __init__(self,
                  model,
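The expanded class docstring pins down the `batch_processor` contract: `batch_processor(model, data, train_mode) -> dict`. Below is a minimal sketch of a processor satisfying that interface; the returned keys (`loss`, `log_vars`, `num_samples`) and the loss computation are illustrative assumptions, since the diff only documents the call signature and the dict return type:

```python
import torch.nn.functional as F


def batch_processor(model, data, train_mode):
    """Run the model on one data batch and return a dict of outputs.

    Only the signature and the dict return type come from the docstring
    above; the keys below are assumptions for illustration.
    """
    img, label = data
    pred = model(img)
    loss = F.cross_entropy(pred, label)
    return dict(
        loss=loss,                        # scalar tensor to backprop
        log_vars=dict(loss=loss.item()),  # plain numbers for loggers
        num_samples=img.size(0))          # batch size of this step
```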
@@ -154,8 +167,8 @@ class Runner(object):
         logging.basicConfig(
             format='%(asctime)s - %(levelname)s - %(message)s', level=level)
         logger = logging.getLogger(__name__)
-        if log_dir:
-            filename = '{}_{}.log'.format(get_time_str(), self.rank)
+        if log_dir and self.rank == 0:
+            filename = '{}.log'.format(get_time_str())
             log_file = osp.join(log_dir, filename)
             self._add_file_handler(logger, log_file, level=level)
         return logger
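This hunk makes file logging rank-aware: every process still logs to the console via `basicConfig`, but only rank 0 attaches a file handler, so the per-rank filename suffix is dropped. A self-contained sketch of the same pattern, assuming `get_dist_info()` resolves the process rank (0 when not distributed):

```python
import logging
import os.path as osp
import time

import torch.distributed as dist


def get_dist_info():
    # Fallback stand-in for the helper imported from .utils above.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1


def init_logger(log_dir=None, level=logging.INFO):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)
    rank, _ = get_dist_info()
    if log_dir and rank == 0:  # only rank 0 writes the log file
        filename = '{}.log'.format(time.strftime('%Y%m%d_%H%M%S'))
        file_handler = logging.FileHandler(osp.join(log_dir, filename))
        file_handler.setLevel(level)
        logger.addHandler(file_handler)
    return logger
```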
@@ -171,17 +184,18 @@ class Runner(object):
                 'lr is not applicable because optimizer does not exist.')
         return [group['lr'] for group in self.optimizer.param_groups]
 
-    def register_hook(self, hook, priority=50):
+    def register_hook(self, hook, priority='NORMAL'):
         """Register a hook into the hook list.
 
         Args:
             hook (:obj:`Hook`): The hook to be registered.
-            priority (int): Hook priority. Lower value means higher priority.
+            priority (int or str or :obj:`Priority`): Hook priority.
+                Lower value means higher priority.
         """
         assert isinstance(hook, Hook)
-        assert isinstance(priority, int) and priority >= 0 and priority <= 100
         if hasattr(hook, 'priority'):
             raise ValueError('"priority" is a reserved attribute for hooks')
+        priority = get_priority(priority)
         hook.priority = priority
         # insert the hook to a sorted list
         inserted = False
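`register_hook` now accepts an int, a string name, or a `Priority` value and normalizes it with the newly imported `get_priority`; note that the deleted `0 <= priority <= 100` assert naturally moves into that helper. The `priority` module itself is not shown in this diff, so the following is only a plausible sketch consistent with the `'NORMAL'` default here and the `'VERY_LOW'` registration further down (the numeric values are assumptions):

```python
from enum import Enum


class Priority(Enum):
    # Assumed levels; lower value means higher priority, per the docstring.
    HIGHEST = 0
    VERY_HIGH = 10
    HIGH = 30
    NORMAL = 50
    LOW = 70
    VERY_LOW = 90
    LOWEST = 100


def get_priority(priority):
    """Convert an int, str or Priority into an integer priority value."""
    if isinstance(priority, int):
        if priority < 0 or priority > 100:
            raise ValueError('priority must be between 0 and 100')
        return priority
    elif isinstance(priority, Priority):
        return priority.value
    elif isinstance(priority, str):
        return Priority[priority.upper()].value
    else:
        raise TypeError('priority must be an int, str or Priority value')
```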
@@ -292,6 +306,17 @@ class Runner(object):
         self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
 
     def run(self, data_loaders, workflow, max_epochs, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, epochs) to specify the
+                running order and epochs. E.g., [('train', 2), ('val', 1)]
+                means running 2 epochs for training and 1 epoch for
+                validation, iteratively.
+            max_epochs (int): Total training epochs.
+        """
         assert isinstance(data_loaders, list)
         assert mmcv.is_list_of(workflow, tuple)
         assert len(data_loaders) == len(workflow)
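Given the new `run` docstring, a typical call pairs one dataloader per workflow phase. The sketch below is illustrative only: the optimizer dict fields assume construction against `torch.optim` via `obj_from_dict`, which the class docstring implies but this diff does not show:

```python
runner = Runner(
    model,
    batch_processor,
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),  # assumed dict shape
    work_dir='./work_dirs/example')

# Alternate 2 training epochs with 1 validation epoch, repeating the
# cycle until 12 training epochs have completed.
runner.run(
    [train_loader, val_loader],
    [('train', 2), ('val', 1)],
    max_epochs=12)
```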
@@ -346,7 +371,7 @@ class Runner(object):
         for info in log_config['hooks']:
             logger_hook = obj_from_dict(
                 info, hooks, default_args=dict(interval=log_interval))
-            self.register_hook(logger_hook, priority=60)
+            self.register_hook(logger_hook, priority='VERY_LOW')
 
     def register_training_hooks(self,
                                 lr_config,
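The loop above instantiates each logger hook from a config dict via `obj_from_dict` and now registers it with the symbolic `'VERY_LOW'` priority instead of the literal 60, so logger hooks still run after the other default training hooks. A config of roughly this shape would drive it; the hook class name is an assumption, as the diff only shows that entries come from `log_config['hooks']`:

```python
log_config = dict(
    interval=50,  # presumably the source of log_interval above
    hooks=[
        dict(type='TextLoggerHook'),  # assumed hook class name
    ])
```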
@@ -356,11 +381,12 @@ class Runner(object):
         """Register default hooks for training.
 
         Default hooks include:
 
         - LrUpdaterHook
         - OptimizerHook
         - CheckpointHook
         - IterTimerHook
-        - LoggerHook
+        - LoggerHook(s)
         """
         if optimizer_config is None:
             optimizer_config = {}