Unverified commit 6b52e9b5, authored by David de la Iglesia Castro and committed by GitHub

Add runner builder (#570)

* Add build_runner

* Parametrize test_runner

* Add imports to runner __init__

* Move max_iters and max_epochs from run() to __init__()

* Add assertion error messages

* Add test_builder

* Make the change backward-compatible

* Raise ValueError if both max_epochs and max_iters are set
parent 2bb1160e
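For orientation (not part of the commit): a minimal sketch of the new usage pattern this change enables, assuming an mmcv checkout that contains it. The `ToyModel` fixture is invented for illustration and simply mirrors the dummy models used in the tests further down.

```python
import logging
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import build_runner


class ToyModel(nn.Module):
    """Dummy model; train_step/val_step return a dict, as the runner expects."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, x, optimizer, **kwargs):
        return dict(loss=self(x).sum())

    def val_step(self, x, optimizer, **kwargs):
        return dict(loss=self(x).sum())


model = ToyModel()
loader = DataLoader(torch.ones((5, 2)))

# New pattern: the runner type and its training budget are plain config;
# non-serializable objects go through default_args.
runner = build_runner(
    dict(type='EpochBasedRunner', max_epochs=1),
    default_args=dict(
        model=model,
        work_dir=tempfile.mkdtemp(),
        optimizer=torch.optim.SGD(model.parameters(), lr=0.02),
        logger=logging.getLogger()))
runner.run([loader, loader], [('train', 1), ('val', 1)])

# Old pattern (still accepted, now deprecated): pass max_epochs to run().
#   runner = EpochBasedRunner(model, optimizer=..., work_dir=..., logger=...)
#   runner.run([loader, loader], [('train', 1), ('val', 1)], max_epochs=1)
```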
# Copyright (c) Open-MMLab. All rights reserved.
from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner
from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
save_checkpoint, weights_to_cpu)
from .dist_utils import get_dist_info, init_dist, master_only
@@ -30,5 +31,6 @@ __all__ = [
'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook'
'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
'RUNNERS'
]
@@ -43,6 +43,8 @@ class BaseRunner(metaclass=ABCMeta):
meta (dict | None): A dict that records some important information, such as
environment info and seed, which will be logged in the logger hook.
Defaults to None.
max_epochs (int, optional): Total training epochs.
max_iters (int, optional): Total training iterations.
"""
def __init__(self,
@@ -51,7 +53,9 @@ class BaseRunner(metaclass=ABCMeta):
optimizer=None,
work_dir=None,
logger=None,
meta=None):
meta=None,
max_iters=None,
max_epochs=None):
if batch_processor is not None:
if not callable(batch_processor):
raise TypeError('batch_processor must be callable, '
@@ -121,8 +125,13 @@ class BaseRunner(metaclass=ABCMeta):
self._epoch = 0
self._iter = 0
self._inner_iter = 0
self._max_epochs = 0
self._max_iters = 0
if max_epochs is not None and max_iters is not None:
raise ValueError(
'Only one of `max_epochs` or `max_iters` can be set.')
self._max_epochs = max_epochs
self._max_iters = max_iters
# TODO: Redesign LogBuffer, it is not flexible and elegant enough
self.log_buffer = LogBuffer()
......
from ..utils import Registry, build_from_cfg
RUNNERS = Registry('runner')
def build_runner(cfg, default_args=None):
return build_from_cfg(cfg, RUNNERS, default_args=default_args)
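Not part of the diff: a hedged sketch of how this registry is meant to be extended by downstream code, following the same Registry/build_from_cfg convention mmcv already uses for optimizers. The `CustomRunner` name is invented, and `model`, `work_dir` and `logger` stand for the same kind of objects as in the sketch near the top of this page.

```python
from mmcv.runner import RUNNERS, EpochBasedRunner, build_runner


# Register a project-specific runner; build_runner resolves it later
# by the `type` string in the config dict.
@RUNNERS.register_module()
class CustomRunner(EpochBasedRunner):
    """Hypothetical runner subclass, used only for illustration."""


runner = build_runner(
    dict(type='CustomRunner', max_epochs=3),
    default_args=dict(model=model, work_dir=work_dir, logger=logger))
assert isinstance(runner, CustomRunner)
```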
@@ -9,10 +9,12 @@ import torch
import mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .utils import get_host_info
@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
@@ -67,7 +69,7 @@ class EpochBasedRunner(BaseRunner):
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs, **kwargs):
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
"""Start running.
Args:
@@ -77,13 +79,19 @@ class EpochBasedRunner(BaseRunner):
running order and epochs. E.g., [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
if max_epochs is not None:
warnings.warn(
'setting max_epochs in run is deprecated, '
'please set max_epochs in runner_config', DeprecationWarning)
self._max_epochs = max_epochs
assert self._max_epochs is not None, (
'max_epochs must be specified during instantiation')
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
@@ -164,6 +172,7 @@ class EpochBasedRunner(BaseRunner):
shutil.copy(filename, dst_file)
@RUNNERS.register_module()
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner."""
......
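Illustrative only, reusing the `ToyModel`/`loader` fixture from the sketch at the top of the page: the backward-compatible path still accepts max_epochs in run(), but warns and overrides the value given at construction; with neither set, run() fails with 'max_epochs must be specified during instantiation'.

```python
import logging
import warnings

from mmcv.runner import EpochBasedRunner

runner = EpochBasedRunner(
    ToyModel(), logger=logging.getLogger(), max_epochs=1)

# Preferred: the budget fixed at construction drives the loop.
runner.run([loader], [('train', 1)])

# Deprecated but still accepted: the argument overrides the stored value
# and a DeprecationWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    runner.run([loader], [('train', 1)], max_epochs=2)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```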
@@ -3,12 +3,14 @@ import os.path as osp
import platform
import shutil
import time
import warnings
import torch
from torch.optim import Optimizer
import mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .hooks import IterTimerHook
from .utils import get_host_info
@@ -41,6 +43,7 @@ class IterLoader:
return len(self._dataloader)
@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
"""Iteration-based Runner.
@@ -79,7 +82,7 @@ class IterBasedRunner(BaseRunner):
self.call_hook('after_val_iter')
self._inner_iter += 1
def run(self, data_loaders, workflow, max_iters, **kwargs):
def run(self, data_loaders, workflow, max_iters=None, **kwargs):
"""Start running.
Args:
@@ -89,24 +92,30 @@ class IterBasedRunner(BaseRunner):
running order and iterations. E.g., [('train', 10000),
('val', 1000)] means running 10000 iterations for training and
1000 iterations for validation, iteratively.
max_iters (int): Total training iterations.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
if max_iters is not None:
warnings.warn(
'setting max_iters in run is deprecated, '
'please set max_iters in runner_config', DeprecationWarning)
self._max_iters = max_iters
assert self._max_iters is not None, (
'max_iters must be specified during instantiation')
self._max_iters = max_iters
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d iters', workflow, max_iters)
self.logger.info('workflow: %s, max: %d iters', workflow,
self._max_iters)
self.call_hook('before_run')
iter_loaders = [IterLoader(x) for x in data_loaders]
self.call_hook('before_epoch')
while self.iter < max_iters:
while self.iter < self._max_iters:
for i, flow in enumerate(workflow):
self._inner_iter = 0
mode, iters = flow
@@ -116,7 +125,7 @@ class IterBasedRunner(BaseRunner):
format(mode))
iter_runner = getattr(self, mode)
for _ in range(iters):
if mode == 'train' and self.iter >= max_iters:
if mode == 'train' and self.iter >= self._max_iters:
break
iter_runner(iter_loaders[i], **kwargs)
......
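The iteration-based runner follows the same scheme. Again a sketch reusing the fixtures from the top of the page: max_iters now lives in the constructor, and asking for both budgets at once is rejected.

```python
import logging
import tempfile

from mmcv.runner import build_runner

default_args = dict(
    model=ToyModel(),
    work_dir=tempfile.mkdtemp(),
    logger=logging.getLogger())

runner = build_runner(
    dict(type='IterBasedRunner', max_iters=5), default_args=default_args)
runner.run([loader], [('train', 5)])

# Specifying both budgets fails fast in BaseRunner.__init__.
try:
    build_runner(
        dict(type='IterBasedRunner', max_iters=5, max_epochs=1),
        default_args=default_args)
except ValueError as err:
    print(err)  # Only one of `max_epochs` or `max_iters` can be set.
```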
@@ -17,9 +17,9 @@ import torch.nn as nn
from torch.nn.init import constant_
from torch.utils.data import DataLoader
from mmcv.runner import (CheckpointHook, EMAHook, EpochBasedRunner,
IterTimerHook, MlflowLoggerHook, PaviLoggerHook,
WandbLoggerHook)
from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
build_runner)
from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
@@ -59,7 +59,7 @@ def test_ema_hook():
checkpointhook = CheckpointHook(interval=1, by_epoch=True)
runner.register_hook(emahook, priority='HIGHEST')
runner.register_hook(checkpointhook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
runner.run([loader, loader], [('train', 1), ('val', 1)])
checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
contain_ema_buffer = False
for name, value in checkpoint['state_dict'].items():
@@ -74,12 +74,12 @@ def test_ema_hook():
work_dir = runner.work_dir
resume_ema_hook = EMAHook(
momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
runner = _build_demo_runner()
runner = _build_demo_runner(max_epochs=2)
runner.model = demo_model
runner.register_hook(resume_ema_hook, priority='HIGHEST')
checkpointhook = CheckpointHook(interval=1, by_epoch=True)
runner.register_hook(checkpointhook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 2)
runner.run([loader, loader], [('train', 1), ('val', 1)])
checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
contain_ema_buffer = False
for name, value in checkpoint['state_dict'].items():
@@ -101,7 +101,7 @@ def test_pavi_hook():
runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)
assert hasattr(hook, 'writer')
@@ -119,7 +119,7 @@ def test_sync_buffers_hook():
loader = DataLoader(torch.ones((5, 2)))
runner = _build_demo_runner()
runner.register_hook_from_cfg(dict(type='SyncBuffersHook'))
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)
@@ -151,7 +151,7 @@ def test_momentum_runner_hook():
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
# TODO: use a more elegant way to check values
@@ -202,7 +202,7 @@ def test_cosine_runner_hook():
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
# TODO: use a more elegant way to check values
@@ -261,7 +261,7 @@ def test_cosine_restart_lr_update_hook():
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
sys.modules['pavi'] = MagicMock()
@@ -280,7 +280,7 @@ def test_cosine_restart_lr_update_hook():
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
# TODO: use a more elegant way to check values
@@ -312,7 +312,7 @@ def test_mlflow_hook(log_model):
hook = MlflowLoggerHook(exp_name='test', log_model=log_model)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)
hook.mlflow.set_experiment.assert_called_with('test')
@@ -335,7 +335,7 @@ def test_wandb_hook():
loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)
hook.wandb.init.assert_called_with()
@@ -347,7 +347,9 @@ def test_wandb_hook():
hook.wandb.join.assert_called_with()
def _build_demo_runner():
def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None):
class Model(nn.Module):
@@ -374,11 +376,15 @@ def _build_demo_runner():
])
tmp_dir = tempfile.mkdtemp()
runner = EpochBasedRunner(
model=model,
work_dir=tmp_dir,
optimizer=optimizer,
logger=logging.getLogger())
runner = build_runner(
dict(type=runner_type),
default_args=dict(
model=model,
work_dir=tmp_dir,
optimizer=optimizer,
logger=logging.getLogger(),
max_epochs=max_epochs,
max_iters=max_iters))
runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config)
return runner
@@ -11,7 +11,8 @@ import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel
from mmcv.runner import EpochBasedRunner
from mmcv.runner import (RUNNERS, EpochBasedRunner, IterBasedRunner,
build_runner)
class OldStyleModel(nn.Module):
@@ -30,7 +31,29 @@ class Model(OldStyleModel):
pass
def test_epoch_based_runner():
def test_build_runner():
temp_root = tempfile.gettempdir()
dir_name = ''.join(
[random.choice(string.ascii_letters) for _ in range(10)])
default_args = dict(
model=Model(),
work_dir=osp.join(temp_root, dir_name),
logger=logging.getLogger())
cfg = dict(type='EpochBasedRunner', max_epochs=1)
runner = build_runner(cfg, default_args=default_args)
assert runner._max_epochs == 1
cfg = dict(type='IterBasedRunner', max_iters=1)
runner = build_runner(cfg, default_args=default_args)
assert runner._max_iters == 1
with pytest.raises(ValueError, match='Only one of'):
cfg = dict(type='IterBasedRunner', max_epochs=1, max_iters=1)
runner = build_runner(cfg, default_args=default_args)
@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
def test_epoch_based_runner(runner_class):
with pytest.warns(UserWarning):
# batch_processor is deprecated
@@ -39,48 +62,46 @@ def test_epoch_based_runner():
def batch_processor():
pass
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
_ = runner_class(model, batch_processor, logger=logging.getLogger())
with pytest.raises(TypeError):
# batch_processor must be callable
model = OldStyleModel()
_ = EpochBasedRunner(
model, batch_processor=0, logger=logging.getLogger())
_ = runner_class(model, batch_processor=0, logger=logging.getLogger())
with pytest.raises(TypeError):
# optimizer must be an optimizer or a dict of optimizers
model = Model()
optimizer = 'NotAOptimizer'
_ = EpochBasedRunner(
_ = runner_class(
model, optimizer=optimizer, logger=logging.getLogger())
with pytest.raises(TypeError):
# optimizer must be an optimizer or a dict of optimizers
model = Model()
optimizers = dict(optim1=torch.optim.Adam(), optim2='NotAOptimizer')
_ = EpochBasedRunner(
_ = runner_class(
model, optimizer=optimizers, logger=logging.getLogger())
with pytest.raises(TypeError):
# logger must be a logging.Logger
model = Model()
_ = EpochBasedRunner(model, logger=None)
_ = runner_class(model, logger=None)
with pytest.raises(TypeError):
# meta must be a dict or None
model = Model()
_ = EpochBasedRunner(model, logger=logging.getLogger(), meta=['list'])
_ = runner_class(model, logger=logging.getLogger(), meta=['list'])
with pytest.raises(AssertionError):
# model must implement the method train_step()
model = OldStyleModel()
_ = EpochBasedRunner(model, logger=logging.getLogger())
_ = runner_class(model, logger=logging.getLogger())
with pytest.raises(TypeError):
# work_dir must be a str or None
model = Model()
_ = EpochBasedRunner(model, work_dir=1, logger=logging.getLogger())
_ = runner_class(model, work_dir=1, logger=logging.getLogger())
with pytest.raises(RuntimeError):
# batch_processor and train_step() cannot be both set
@@ -89,8 +110,7 @@ def test_epoch_based_runner():
pass
model = Model()
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
_ = runner_class(model, batch_processor, logger=logging.getLogger())
# test work_dir
model = Model()
@@ -98,23 +118,24 @@ def test_epoch_based_runner():
dir_name = ''.join(
[random.choice(string.ascii_letters) for _ in range(10)])
work_dir = osp.join(temp_root, dir_name)
_ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
_ = runner_class(model, work_dir=work_dir, logger=logging.getLogger())
assert osp.isdir(work_dir)
_ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
_ = runner_class(model, work_dir=work_dir, logger=logging.getLogger())
assert osp.isdir(work_dir)
os.removedirs(work_dir)
def test_runner_with_parallel():
@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
def test_runner_with_parallel(runner_class):
def batch_processor():
pass
model = MMDataParallel(OldStyleModel())
_ = EpochBasedRunner(model, batch_processor, logger=logging.getLogger())
_ = runner_class(model, batch_processor, logger=logging.getLogger())
model = MMDataParallel(Model())
_ = EpochBasedRunner(model, logger=logging.getLogger())
_ = runner_class(model, logger=logging.getLogger())
with pytest.raises(RuntimeError):
# batch_processor and train_step() cannot be both set
@@ -123,13 +144,13 @@ def test_runner_with_parallel():
pass
model = MMDataParallel(Model())
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
_ = runner_class(model, batch_processor, logger=logging.getLogger())
def test_save_checkpoint():
@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
def test_save_checkpoint(runner_class):
model = Model()
runner = EpochBasedRunner(model=model, logger=logging.getLogger())
runner = runner_class(model=model, logger=logging.getLogger())
with pytest.raises(TypeError):
# meta should be None or dict
@@ -139,18 +160,23 @@ def test_save_checkpoint():
runner.save_checkpoint(root)
latest_path = osp.join(root, 'latest.pth')
epoch1_path = osp.join(root, 'epoch_1.pth')
assert osp.exists(latest_path)
assert osp.exists(epoch1_path)
assert osp.realpath(latest_path) == osp.realpath(epoch1_path)
if isinstance(runner, EpochBasedRunner):
first_ckp_path = osp.join(root, 'epoch_1.pth')
elif isinstance(runner, IterBasedRunner):
first_ckp_path = osp.join(root, 'iter_1.pth')
assert osp.exists(first_ckp_path)
assert osp.realpath(latest_path) == osp.realpath(first_ckp_path)
torch.load(latest_path)
def test_build_lr_momentum_hook():
@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
def test_build_lr_momentum_hook(runner_class):
model = Model()
runner = EpochBasedRunner(model=model, logger=logging.getLogger())
runner = runner_class(model=model, logger=logging.getLogger())
# test policy that is already title
lr_config = dict(
......