Unverified Commit 67a26da9 authored by Harry, committed by GitHub

Add IterBasedRunner (#314)



* feat: add IterBasedRunner

* fix: unittest

* feat: more unittest

* fix: expose dataloader len

* minor updates of BaseRunner

* refactor: remove CosineRestartLrUpdaterHook

* style: add docstring

* refactor: update IterTextLoggerHook: fstring and exp_name

* fix: epoch_runner unittest

* refactor: remove IterBasedTextLogger

* fix: old IterTextLoggerHook issue

* refactor: remove __len__ of IterLoader

* feat: add IterBasedRunner to init

* feat: add __len__ to IterLoader

* fix some docstrings

* refactor: use is_parallel_module

* fix: import issue

* fix: runner unittest missing logger

* fix checkpoints

* feat: add by_epoch default value when IterBasedRunner registers logger_hook

* refactor: remove setting by_epoch in log_config

* minor refactoring

* docs: add docstring

* fix: remove unused doc

* update the log info for saving checkpoints
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
parent 61f9e91c
# Copyright (c) Open-MMLab. All rights reserved.
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from .scatter_gather import scatter_kwargs
class MMDistributedDataParallel(DistributedDataParallel):
"""The DDP module that supports DataContainer.
MMDDP has two main differences from PyTorch DDP:
- It supports a custom type :class:`DataContainer`, which allows more
flexible control of input data.
- It implements two APIs, ``train_step()`` and ``val_step()``.
"""
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
"""train_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.train_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.train_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.train_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if torch.__version__ > '1.2':
self.require_forward_param_sync = False
return output
def val_step(self, *inputs, **kwargs):
"""val_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.val_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.val_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.val_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if torch.__version__ > '1.2':
self.require_forward_param_sync = False
return output
@@ -52,3 +52,15 @@ class MMDistributedDataParallel(nn.Module):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
return self.module(*inputs[0], **kwargs[0])
def train_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.train_step(*inputs[0], **kwargs[0])
return output
def val_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.val_step(*inputs[0], **kwargs[0])
return output
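Both wrappers now forward train_step()/val_step() to the wrapped module, so the module itself is expected to define these methods. A minimal sketch of such a module (the toy model and its loss are hypothetical illustrations, not part of this change set):

import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model exposing the train_step()/val_step() interface."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = ((self.linear(x) - y) ** 2).mean()
        # The runner's OptimizerHook performs backward() and step() using the
        # returned 'loss'; train_step() only computes and reports it here.
        return dict(loss=loss, log_vars=dict(loss=loss.item()),
                    num_samples=x.size(0))

    def val_step(self, data_batch, optimizer, **kwargs):
        return self.train_step(data_batch, optimizer, **kwargs)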
@@ -8,6 +8,7 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
MlflowLoggerHook, OptimizerHook, PaviLoggerHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
@@ -16,14 +17,15 @@ from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict
__all__ = [
'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', 'MlflowLoggerHook', '_load_checkpoint',
'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
'Priority', 'get_priority', 'get_host_info', 'get_time_str',
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only',
'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
'IterBasedRunner'
]
@@ -5,6 +5,7 @@ import warnings
from abc import ABCMeta, abstractmethod
import torch
from torch.optim import Optimizer
import mmcv
from ..parallel import is_parallel_module
@@ -31,12 +32,14 @@ class BaseRunner(metaclass=ABCMeta):
batch_processor (callable): A callable method that processes a data
batch. The interface of this method should be
`batch_processor(model, data, train_mode) -> dict`
optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
optimizer (in most cases) or a dict of optimizers (in models that
require more than one optimizer, e.g., GAN).
work_dir (str, optional): The working directory to save checkpoints
and logs. Defaults to None.
logger (:obj:`logging.Logger`): Logger used during training.
Defaults to None. (The default value is just for backward
compatibility.)
meta (dict | None): A dict that records some important information such
as environment info and seed, which will be logged in logger hook.
Defaults to None.
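The dict-of-optimizers option described above targets models that optimize more than one network. A hedged sketch of what such a dict might look like (the generator/discriminator names are illustrative, not from this change set):

import torch
import torch.nn as nn

# Hypothetical two-network setup, e.g. a GAN.
generator = nn.Linear(16, 16)
discriminator = nn.Linear(16, 1)

# A runner accepts either a single optimizer ...
optimizer = torch.optim.SGD(generator.parameters(), lr=0.01)

# ... or a dict mapping names to optimizers, one per sub-network.
optimizer = dict(
    generator=torch.optim.Adam(generator.parameters(), lr=1e-4),
    discriminator=torch.optim.Adam(discriminator.parameters(), lr=4e-4))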
...@@ -67,9 +70,34 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -67,9 +70,34 @@ class BaseRunner(metaclass=ABCMeta):
'cannot be both available.') 'cannot be both available.')
else: else:
assert hasattr(model, 'train_step') assert hasattr(model, 'train_step')
# check the type of `optimizer`
if isinstance(optimizer, dict):
for name, optim in optimizer.items():
if not isinstance(optim, Optimizer):
raise TypeError(
f'optimizer must be a dict of torch.optim.Optimizers, '
f'but optimizer["{name}"] is a {type(optim)}')
elif not isinstance(optimizer, Optimizer) and optimizer is not None:
raise TypeError(
f'optimizer must be a torch.optim.Optimizer object '
f'or dict or None, but got {type(optimizer)}')
# check the type of `logger`
if not isinstance(logger, logging.Logger):
raise TypeError(f'logger must be a logging.Logger object, '
f'but got {type(logger)}')
# check the type of `meta`
if meta is not None and not isinstance(meta, dict):
raise TypeError(
f'meta must be a dict or None, but got {type(meta)}')
self.model = model
self.batch_processor = batch_processor
self.optimizer = optimizer
self.logger = logger
self.meta = meta
# create work_dir
if mmcv.is_str(work_dir):
@@ -86,13 +114,6 @@ class BaseRunner(metaclass=ABCMeta):
else:
self._model_name = self.model.__class__.__name__
assert logging is not None
self.logger = logger
if meta is not None:
assert isinstance(meta, dict), '"meta" must be a dict or None'
self.meta = meta
self._rank, self._world_size = get_dist_info()
self.timestamp = get_time_str()
self.mode = None
@@ -176,30 +197,50 @@ class BaseRunner(metaclass=ABCMeta):
"""Get current learning rates.
Returns:
list[float] | dict[str, list[float]]: Current learning rates of all
param groups. If the runner has a dict of optimizers, this
method will return a dict.
"""
if isinstance(self.optimizer, torch.optim.Optimizer):
lr = [group['lr'] for group in self.optimizer.param_groups]
elif isinstance(self.optimizer, dict):
lr = dict()
for name, optim in self.optimizer.items():
lr[name] = [group['lr'] for group in optim.param_groups]
else:
raise RuntimeError(
'lr is not applicable because optimizer does not exist.')
return lr
def current_momentum(self):
"""Get current momentums.
Returns:
list[float] | dict[str, list[float]]: Current momentums of all
param groups. If the runner has a dict of optimizers, this
method will return a dict.
"""
def _get_momentum(optimizer):
momentums = []
for group in optimizer.param_groups:
if 'momentum' in group.keys():
momentums.append(group['momentum'])
elif 'betas' in group.keys():
momentums.append(group['betas'][0])
else:
momentums.append(0)
return momentums
if self.optimizer is None:
raise RuntimeError(
'momentum is not applicable because optimizer does not exist.')
elif isinstance(self.optimizer, torch.optim.Optimizer):
momentums = _get_momentum(self.optimizer)
elif isinstance(self.optimizer, dict):
momentums = dict()
for name, optim in self.optimizer.items():
momentums[name] = _get_momentum(optim)
return momentums
def register_hook(self, hook, priority='NORMAL'):
...
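For the single-optimizer case these helpers reduce to reading param_groups directly; a small sketch of the values they would return (the numbers are illustrative only):

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
optim = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9)

# Equivalent to current_lr()/current_momentum() with a single optimizer:
lrs = [group['lr'] for group in optim.param_groups]              # [0.02]
momentums = [group['momentum'] for group in optim.param_groups]  # [0.9]

# With a dict of optimizers, both methods would instead return a dict,
# e.g. {'generator': [1e-4], 'discriminator': [4e-4]}.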
@@ -9,10 +9,12 @@ from importlib import import_module
import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
import mmcv
from ..fileio import load as load_file
from ..parallel import is_parallel_module
from ..utils import mkdir_or_exist
from .dist_utils import get_dist_info
@@ -59,6 +61,10 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_parallel_module(module):
module = module.module
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
@@ -228,10 +234,7 @@ def load_checkpoint(model,
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
@@ -269,15 +272,20 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
mmcv.mkdir_or_exist(osp.dirname(filename))
if is_parallel_module(model):
model = model.module
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(model.state_dict())
}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint['optimizer'] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint['optimizer'] = {}
for name, optim in optimizer.items():
checkpoint['optimizer'][name] = optim.state_dict()
# immediately flush buffer
with open(filename, 'wb') as f:
torch.save(checkpoint, f)
...
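With a dict of optimizers, the saved file nests one state_dict per optimizer name under the 'optimizer' key. A hedged sketch of inspecting such a checkpoint (the file path is hypothetical):

import torch

ckpt = torch.load('work_dir/iter_1000.pth', map_location='cpu')
print(ckpt.keys())               # dict_keys(['meta', 'state_dict', 'optimizer'])
print(ckpt['optimizer'].keys())  # e.g. dict_keys(['generator', 'discriminator'])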
@@ -7,14 +7,34 @@ from .hook import HOOKS, Hook
@HOOKS.register_module()
class CheckpointHook(Hook):
"""Save checkpoints periodically.
Args:
interval (int): The saving period. If ``by_epoch=True``, interval
indicates epochs, otherwise it indicates iterations.
Default: -1, which means "never".
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default: True.
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional): The directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default.
max_keep_ckpts (int, optional): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
"""
def __init__(self,
interval=-1,
by_epoch=True,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
@@ -22,9 +42,10 @@ class CheckpointHook(Hook):
@master_only
def after_train_epoch(self, runner):
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
return
runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint(
@@ -41,3 +62,29 @@ class CheckpointHook(Hook):
os.remove(ckpt_path)
else:
break
@master_only
def after_train_iter(self, runner):
if self.by_epoch or not self.every_n_iters(runner, self.interval):
return
runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
# remove other checkpoints
if self.max_keep_ckpts > 0:
filename_tmpl = self.args.get('filename_tmpl', 'iter_{}.pth')
current_iter = runner.iter + 1
for _iter in range(
current_iter - self.max_keep_ckpts * self.interval, 0,
-self.interval):
ckpt_path = os.path.join(self.out_dir,
filename_tmpl.format(_iter))
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
else:
break
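With the new by_epoch flag, iteration-based checkpointing becomes a matter of configuration. A minimal config-style sketch (the keys follow the usual mmcv hook-config conventions and are assumptions here, not part of this diff):

# Save every 5000 iterations and keep only the three most recent files.
checkpoint_config = dict(interval=5000, by_epoch=False, max_keep_ckpts=3)

IterBasedRunner.register_training_hooks() further below already defaults by_epoch to False for this config, so the flag could also be omitted there.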
@@ -71,8 +71,22 @@ class PaviLoggerHook(LoggerHook):
for tag, val in runner.log_buffer.output.items():
if tag not in ['time', 'data_time'] and is_scalar(val):
tags[tag] = val
# add learning rate
lrs = runner.current_lr()
if isinstance(lrs, dict):
for name, value in lrs.items():
tags[f'learning_rate/{name}'] = value[0]
else:
tags['learning_rate'] = lrs[0]
# add momentum
momentums = runner.current_momentum()
if isinstance(momentums, dict):
for name, value in momentums.items():
tags[f'momentum/{name}'] = value[0]
else:
tags['momentum'] = momentums[0]
if tags:
self.writer.add_scalars(runner.mode, tags, runner.iter)
...
@@ -53,10 +53,22 @@ class TensorboardLoggerHook(LoggerHook):
else:
self.writer.add_scalar(tag, runner.log_buffer.output[var],
runner.iter)
# add learning rate
lrs = runner.current_lr()
if isinstance(lrs, dict):
for name, value in lrs.items():
self.writer.add_scalar(f'learning_rate/{name}', value[0],
runner.iter)
else:
self.writer.add_scalar('learning_rate', lrs[0], runner.iter)
# add momentum
momentums = runner.current_momentum()
if isinstance(momentums, dict):
for name, value in momentums.items():
self.writer.add_scalar(f'momentum/{name}', value[0],
runner.iter)
else:
self.writer.add_scalar('momentum', momentums[0], runner.iter)
@master_only
def after_run(self, runner):
...
@@ -19,6 +19,7 @@ class TextLoggerHook(LoggerHook):
saved in json file.
Args:
by_epoch (bool): Whether EpochBasedRunner is used.
interval (int): Logging interval (every k iterations).
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
@@ -29,11 +30,13 @@ class TextLoggerHook(LoggerHook):
"""
def __init__(self,
by_epoch=True,
interval=10,
ignore_last=True,
reset_flag=False,
interval_exp_name=1000):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
self.by_epoch = by_epoch
self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
@@ -55,17 +58,32 @@ class TextLoggerHook(LoggerHook):
return mem_mb.item()
def _log_info(self, log_dict, runner):
# print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_inner_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(runner)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
runner.logger.info(exp_info)
if runner.mode == 'train':
if isinstance(log_dict['lr'], dict):
lr_str = []
for k, val in log_dict['lr'].items():
lr_str.append(f'lr_{k}: {val:.3e}')
lr_str = ' '.join(lr_str)
else:
lr_str = f'lr: {log_dict["lr"]:.3e}'
# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
if self.by_epoch:
log_str = f'Epoch [{log_dict["epoch"]}]' \
f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
else:
log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
log_str += f'{lr_str}, '
if 'time' in log_dict.keys():
self.time_sec_tot += (log_dict['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (
@@ -79,8 +97,12 @@ class TextLoggerHook(LoggerHook):
if torch.cuda.is_available():
log_str += f'memory: {log_dict["memory"]}, '
else:
if self.by_epoch:
log_str = f'Epoch({log_dict["mode"]}) ' \
f'[{log_dict["epoch"] - 1}][{log_dict["iter"]}]\t'
else:
log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
log_items = []
for name, val in log_dict.items():
# TODO: resolve this hack
@@ -94,6 +116,7 @@ class TextLoggerHook(LoggerHook):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)
runner.logger.info(log_str)
def _dump_log(self, log_dict, runner):
@@ -123,7 +146,16 @@ class TextLoggerHook(LoggerHook):
log_dict['epoch'] = runner.epoch + 1
log_dict['iter'] = runner.inner_iter + 1
# only record lr of the first param group
cur_lr = runner.current_lr()
if isinstance(cur_lr, list):
log_dict['lr'] = cur_lr[0]
else:
assert isinstance(cur_lr, dict)
log_dict['lr'] = {}
for k, lr_ in cur_lr.items():
assert isinstance(lr_, list)
log_dict['lr'].update({k: lr_[0]})
if mode == 'train':
log_dict['time'] = runner.log_buffer.output['time']
log_dict['data_time'] = runner.log_buffer.output['data_time']
...
@@ -55,14 +55,31 @@ class LrUpdaterHook(Hook):
self.regular_lr = []  # expected lr if no warming up is performed
def _set_lr(self, runner, lr_groups):
if isinstance(runner.optimizer, dict):
for k, optim in runner.optimizer.items():
for param_group, lr in zip(optim.param_groups, lr_groups[k]):
param_group['lr'] = lr
else:
for param_group, lr in zip(runner.optimizer.param_groups,
lr_groups):
param_group['lr'] = lr
def get_lr(self, runner, base_lr):
raise NotImplementedError
def get_regular_lr(self, runner):
if isinstance(runner.optimizer, dict):
lr_groups = {}
for k in runner.optimizer.keys():
_lr_group = [
self.get_lr(runner, _base_lr)
for _base_lr in self.base_lr[k]
]
lr_groups.update({k: _lr_group})
return lr_groups
else:
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
def get_warmup_lr(self, cur_iters):
if self.warmup == 'constant':
@@ -78,11 +95,21 @@ class LrUpdaterHook(Hook):
def before_run(self, runner):
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# it will be set according to the optimizer params
if isinstance(runner.optimizer, dict):
self.base_lr = {}
for k, optim in runner.optimizer.items():
for group in optim.param_groups:
group.setdefault('initial_lr', group['lr'])
_base_lr = [
group['initial_lr'] for group in optim.param_groups
]
self.base_lr.update({k: _base_lr})
else:
for group in runner.optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lr = [
group['initial_lr'] for group in runner.optimizer.param_groups
]
def before_train_epoch(self, runner):
if not self.by_epoch:
@@ -213,6 +240,7 @@ class CosineAnealingLrUpdaterHook(LrUpdaterHook):
else:
progress = runner.iter
max_progress = runner.max_iters
if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
@@ -224,7 +252,7 @@ class CosineAnealingLrUpdaterHook(LrUpdaterHook):
class CyclicLrUpdaterHook(LrUpdaterHook):
"""Cyclic LR Scheduler
Implement the cyclical learning rate policy (CLR) described in
https://arxiv.org/pdf/1506.01186.pdf
Different from the original paper, we use cosine annealing rather than
...
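The same by_epoch switch applies to learning-rate schedules when training with IterBasedRunner; a hedged config sketch (the policy name and step values are illustrative, following the usual mmcv config conventions):

# Step decay expressed in iterations rather than epochs.
lr_config = dict(policy='step', by_epoch=False, step=[60000, 80000])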
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import torch
from torch.optim import Optimizer
import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .hooks import IterTimerHook
from .utils import get_host_info
class IterLoader:
def __init__(self, dataloader):
self._dataloader = dataloader
self.iter_loader = iter(self._dataloader)
self._epoch = 0
@property
def epoch(self):
return self._epoch
def __next__(self):
try:
data = next(self.iter_loader)
except StopIteration:
self._epoch += 1
if hasattr(self._dataloader.sampler, 'set_epoch'):
self._dataloader.sampler.set_epoch(self._epoch)
self.iter_loader = iter(self._dataloader)
data = next(self.iter_loader)
return data
def __len__(self):
return len(self._dataloader)
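IterLoader turns a finite DataLoader into an endless iterator: next() restarts the loader on StopIteration, bumps the epoch counter, and re-seeds distributed samplers via set_epoch(). A minimal usage sketch (toy dataset, illustrative only):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
loader = IterLoader(DataLoader(dataset, batch_size=4))

for _ in range(6):       # more steps than one pass over the 10 samples
    batch = next(loader)
print(loader.epoch)      # > 0 once the underlying loader has wrapped around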
class IterBasedRunner(BaseRunner):
"""Iteration-based Runner.
This runner trains models iteration by iteration.
"""
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._epoch = data_loader.epoch
self.call_hook('before_train_iter')
data_batch = next(data_loader)
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('model.train_step() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._inner_iter += 1
self._iter += 1
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self._inner_iter = 0
self.data_loader = data_loader
self.call_hook('before_val_iter')
data_batch = next(data_loader)
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('model.val_step() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self._inner_iter += 1
def run(self, data_loaders, workflow, max_iters, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, iters) to specify the
running order and iterations. E.g., [('train', 10000),
('val', 1000)] means running 10000 iterations for training and
1000 iterations for validation, iteratively.
max_iters (int): Total training iterations.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_iters = max_iters
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d iters', workflow, max_iters)
self.call_hook('before_run')
iter_loaders = [IterLoader(x) for x in data_loaders]
self.call_hook('before_epoch')
while self.iter < max_iters:
for i, flow in enumerate(workflow):
mode, iters = flow
if not isinstance(mode, str) or not hasattr(self, mode):
raise ValueError(
'runner has no method named "{}" to run a workflow'.
format(mode))
iter_runner = getattr(self, mode)
for _ in range(iters):
if mode == 'train' and self.iter >= max_iters:
return
iter_runner(iter_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_epoch')
self.call_hook('after_run')
def resume(self,
checkpoint,
resume_optimizer=True,
map_location='default'):
"""Resume model from checkpoint.
Args:
checkpoint (str): Checkpoint to resume from.
resume_optimizer (bool, optional): Whether to resume the optimizer(s)
if the checkpoint file includes optimizer(s). Defaults to True.
map_location (str, optional): Same as :func:`torch.load`.
Defaults to 'default'.
"""
if map_location == 'default':
device_id = torch.cuda.current_device()
checkpoint = self.load_checkpoint(
checkpoint,
map_location=lambda storage, loc: storage.cuda(device_id))
else:
checkpoint = self.load_checkpoint(
checkpoint, map_location=map_location)
self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter']
self._inner_iter = checkpoint['meta']['iter']
if 'optimizer' in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer'])
elif isinstance(self.optimizer, dict):
for k in self.optimizer.keys():
self.optimizer[k].load_state_dict(
checkpoint['optimizer'][k])
self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
def save_checkpoint(self,
out_dir,
filename_tmpl='iter_{}.pth',
meta=None,
save_optimizer=True,
create_symlink=True):
"""Save checkpoint to file.
Args:
out_dir (str): Directory to save checkpoint files.
filename_tmpl (str, optional): Checkpoint file template.
Defaults to 'iter_{}.pth'.
meta (dict, optional): Metadata to be saved in checkpoint.
Defaults to None.
save_optimizer (bool, optional): Whether save optimizer.
Defaults to True.
create_symlink (bool, optional): Whether create symlink to the
latest checkpoint file. Defaults to True.
"""
if meta is None:
meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
elif isinstance(meta, dict):
meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
else:
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
if self.meta is not None:
meta.update(self.meta)
filename = filename_tmpl.format(self.iter + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None):
"""Register default hooks for iter-based training.
Default hooks include:
- LrUpdaterHook
- MomentumUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook(s)
"""
if checkpoint_config is not None:
checkpoint_config.setdefault('by_epoch', False)
if lr_config is not None:
lr_config.setdefault('by_epoch', False)
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
self.register_logger_hooks(log_config)
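Putting the pieces together, a hedged end-to-end sketch of driving training with IterBasedRunner (the model, data, and config values are illustrative; the hook config keys follow general mmcv conventions rather than anything defined in this diff):

import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from mmcv.runner import IterBasedRunner


class ToyModel(nn.Module):
    """Hypothetical model implementing train_step()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = ((self.linear(x) - y) ** 2).mean()
        return dict(loss=loss, log_vars=dict(loss=loss.item()),
                    num_samples=x.size(0))


model = ToyModel()
loader = DataLoader(
    TensorDataset(torch.randn(64, 2), torch.randn(64, 1)), batch_size=8)

runner = IterBasedRunner(
    model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    work_dir='./work_dir',
    logger=logging.getLogger(__name__))
runner.register_training_hooks(
    lr_config=dict(policy='step', by_epoch=False, step=[600]),
    optimizer_config=dict(grad_clip=None),
    checkpoint_config=dict(interval=200),
    log_config=dict(
        interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]))
# Workflow: one training iteration per cycle until 1000 iterations in total.
runner.run([loader], [('train', 1)], max_iters=1000)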
@@ -39,22 +39,48 @@ def test_epoch_based_runner():
def batch_processor():
pass
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
with pytest.raises(TypeError):
# batch_processor must be callable
model = OldStyleModel()
_ = EpochBasedRunner(
model, batch_processor=0, logger=logging.getLogger())
with pytest.raises(TypeError):
# optimizer must be an optimizer or a dict of optimizers
model = Model()
optimizer = 'NotAOptimizer'
_ = EpochBasedRunner(
model, optimizer=optimizer, logger=logging.getLogger())
with pytest.raises(TypeError):
# optimizer must be an optimizer or a dict of optimizers
model = Model()
optimizers = dict(optim1=torch.optim.Adam(), optim2='NotAOptimizer')
_ = EpochBasedRunner(
model, optimizer=optimizers, logger=logging.getLogger())
with pytest.raises(TypeError):
# logger must be a logging.Logger
model = Model()
_ = EpochBasedRunner(model, logger=None)
with pytest.raises(TypeError):
# meta must be a dict or None
model = Model()
_ = EpochBasedRunner(model, logger=logging.getLogger(), meta=['list'])
with pytest.raises(AssertionError):
# model must implement the method train_step()
model = OldStyleModel()
_ = EpochBasedRunner(model, logger=logging.getLogger())
with pytest.raises(TypeError):
# work_dir must be a str or None
model = Model()
_ = EpochBasedRunner(model, work_dir=1, logger=logging.getLogger())
with pytest.raises(RuntimeError):
# batch_processor and train_step() cannot be both set
@@ -63,7 +89,8 @@ def test_epoch_based_runner():
pass
model = Model()
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
# test work_dir
model = Model()
@@ -71,9 +98,9 @@ def test_epoch_based_runner():
dir_name = ''.join(
[random.choice(string.ascii_letters) for _ in range(10)])
work_dir = osp.join(temp_root, dir_name)
_ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
assert osp.isdir(work_dir)
_ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
assert osp.isdir(work_dir)
os.removedirs(work_dir)
@@ -84,7 +111,7 @@ def test_runner_with_parallel():
pass
model = MMDataParallel(OldStyleModel())
_ = EpochBasedRunner(model, batch_processor, logger=logging.getLogger())
with pytest.raises(RuntimeError):
# batch_processor and train_step() cannot be both set
@@ -93,7 +120,8 @@ def test_runner_with_parallel():
pass
model = MMDataParallel(Model())
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
def test_save_checkpoint():
...