Unverified Commit 6b52e9b5 authored by David de la Iglesia Castro, committed by GitHub

Add runner builder (#570)

* Add build_runner

* Parametrize test_runner

* Add imports to runner __init__

* Refactor max_iters and max_epochs from run to init

* Add assertion error messages

* Add test_builder

* Make the change backward-compatible

* Raise ValueError if both max_epochs and max_iters are set
parent 2bb1160e
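For reference, after this commit a runner is built from a config dict instead of being instantiated directly, and the epoch/iteration budget moves from run() into the constructor. A minimal sketch of the new usage, mirroring the test helper in the diff below; the toy model, loader, optimizer and work_dir are placeholders, not part of the patch:

import logging
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import build_runner


class ToyModel(nn.Module):  # placeholder model with the required train_step()

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def train_step(self, x, optimizer, **kwargs):
        return dict(loss=self.linear(x).sum())


model = ToyModel()
runner = build_runner(
    dict(type='EpochBasedRunner', max_epochs=1),  # budget lives in the config
    default_args=dict(
        model=model,
        optimizer=torch.optim.SGD(model.parameters(), lr=0.02),
        work_dir=tempfile.mkdtemp(),
        logger=logging.getLogger()))
runner.run([DataLoader(torch.ones((5, 2)))], [('train', 1)])  # no max_epochs here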
 # Copyright (c) Open-MMLab. All rights reserved.
 from .base_runner import BaseRunner
+from .builder import RUNNERS, build_runner
 from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
                          save_checkpoint, weights_to_cpu)
 from .dist_utils import get_dist_info, init_dist, master_only
@@ -30,5 +31,6 @@ __all__ = [
     'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
     'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
-    'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook'
+    'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
+    'RUNNERS'
 ]
@@ -43,6 +43,8 @@ class BaseRunner(metaclass=ABCMeta):
         meta (dict | None): A dict records some import information such as
             environment info and seed, which will be logged in logger hook.
             Defaults to None.
+        max_epochs (int, optional): Total training epochs.
+        max_iters (int, optional): Total training iterations.
     """

     def __init__(self,
@@ -51,7 +53,9 @@ class BaseRunner(metaclass=ABCMeta):
                  optimizer=None,
                  work_dir=None,
                  logger=None,
-                 meta=None):
+                 meta=None,
+                 max_iters=None,
+                 max_epochs=None):
         if batch_processor is not None:
             if not callable(batch_processor):
                 raise TypeError('batch_processor must be callable, '
@@ -121,8 +125,13 @@ class BaseRunner(metaclass=ABCMeta):
         self._epoch = 0
         self._iter = 0
         self._inner_iter = 0
-        self._max_epochs = 0
-        self._max_iters = 0
+
+        if max_epochs is not None and max_iters is not None:
+            raise ValueError(
+                'Only one of `max_epochs` or `max_iters` can be set.')
+
+        self._max_epochs = max_epochs
+        self._max_iters = max_iters

         # TODO: Redesign LogBuffer, it is not flexible and elegant enough
         self.log_buffer = LogBuffer()
...
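A quick illustration of the new constructor contract: setting both limits now fails fast in __init__, while setting neither is still allowed for backward compatibility and is only asserted when run() is called. The toy model and logger below are placeholders, not part of the patch:

import logging

import torch.nn as nn

from mmcv.runner import EpochBasedRunner


class ToyModel(nn.Module):  # placeholder with the required train_step()

    def train_step(self, *args, **kwargs):
        return dict(loss=0)


try:
    EpochBasedRunner(
        ToyModel(), logger=logging.getLogger(),
        max_epochs=2, max_iters=100)
except ValueError as err:
    print(err)  # Only one of `max_epochs` or `max_iters` can be set.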
+from ..utils import Registry, build_from_cfg
+
+RUNNERS = Registry('runner')
+
+
+def build_runner(cfg, default_args=None):
+    return build_from_cfg(cfg, RUNNERS, default_args=default_args)
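Because RUNNERS is an ordinary mmcv Registry, downstream projects can register their own runner and select it by name in a config dict; a sketch of that pattern (CustomRunner is a hypothetical subclass used only for illustration, and the commented build call assumes a model and logger defined elsewhere):

from mmcv.runner import RUNNERS, EpochBasedRunner, build_runner


@RUNNERS.register_module()
class CustomRunner(EpochBasedRunner):
    """Hypothetical subclass used only to illustrate registration."""


# Later, e.g. from a parsed config file:
runner_cfg = dict(type='CustomRunner', max_epochs=3)
# runner = build_runner(runner_cfg, default_args=dict(model=model, logger=logger))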
@@ -9,10 +9,12 @@ import torch

 import mmcv
 from .base_runner import BaseRunner
+from .builder import RUNNERS
 from .checkpoint import save_checkpoint
 from .utils import get_host_info


+@RUNNERS.register_module()
 class EpochBasedRunner(BaseRunner):
     """Epoch-based Runner.
@@ -67,7 +69,7 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('after_val_epoch')

-    def run(self, data_loaders, workflow, max_epochs, **kwargs):
+    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
         """Start running.

         Args:
@@ -77,13 +79,19 @@ class EpochBasedRunner(BaseRunner):
                 running order and epochs. E.g, [('train', 2), ('val', 1)] means
                 running 2 epochs for training and 1 epoch for validation,
                 iteratively.
-            max_epochs (int): Total training epochs.
         """
         assert isinstance(data_loaders, list)
         assert mmcv.is_list_of(workflow, tuple)
         assert len(data_loaders) == len(workflow)
+        if max_epochs is not None:
+            warnings.warn(
+                'setting max_epochs in run is deprecated, '
+                'please set max_epochs in runner_config', DeprecationWarning)
+            self._max_epochs = max_epochs

-        self._max_epochs = max_epochs
+        assert self._max_epochs is not None, (
+            'max_epochs must be specified during instantiation')
+
         for i, flow in enumerate(workflow):
             mode, epochs = flow
             if mode == 'train':
@@ -164,6 +172,7 @@ class EpochBasedRunner(BaseRunner):
         shutil.copy(filename, dst_file)


+@RUNNERS.register_module()
 class Runner(EpochBasedRunner):
     """Deprecated name of EpochBasedRunner."""
...
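The effect on existing call sites, roughly: the old max_epochs argument to run() still works but now emits a DeprecationWarning, while the preferred path fixes the budget at construction time. A self-contained sketch with a placeholder model and loader (not part of the patch):

import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import EpochBasedRunner


class ToyModel(nn.Module):  # placeholder, not part of the patch

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def train_step(self, x, optimizer, **kwargs):
        return dict(loss=self.linear(x).sum())


loader = DataLoader(torch.ones((5, 2)))
logger = logging.getLogger()

# Old call site: still accepted, but run() now emits a DeprecationWarning
# asking you to move max_epochs into the runner config.
runner = EpochBasedRunner(ToyModel(), logger=logger)
runner.run([loader], [('train', 1)], 1)  # positional max_epochs

# New call site: the budget is fixed at construction, run() takes no max_epochs.
runner = EpochBasedRunner(ToyModel(), logger=logger, max_epochs=1)
runner.run([loader], [('train', 1)])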
@@ -3,12 +3,14 @@ import os.path as osp
 import platform
 import shutil
 import time
+import warnings

 import torch
 from torch.optim import Optimizer

 import mmcv
 from .base_runner import BaseRunner
+from .builder import RUNNERS
 from .checkpoint import save_checkpoint
 from .hooks import IterTimerHook
 from .utils import get_host_info
@@ -41,6 +43,7 @@ class IterLoader:
         return len(self._dataloader)


+@RUNNERS.register_module()
 class IterBasedRunner(BaseRunner):
     """Iteration-based Runner.
@@ -79,7 +82,7 @@ class IterBasedRunner(BaseRunner):
         self.call_hook('after_val_iter')
         self._inner_iter += 1

-    def run(self, data_loaders, workflow, max_iters, **kwargs):
+    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
         """Start running.

         Args:
@@ -89,24 +92,30 @@ class IterBasedRunner(BaseRunner):
                 running order and iterations. E.g, [('train', 10000),
                 ('val', 1000)] means running 10000 iterations for training and
                 1000 iterations for validation, iteratively.
-            max_iters (int): Total training iterations.
         """
         assert isinstance(data_loaders, list)
         assert mmcv.is_list_of(workflow, tuple)
         assert len(data_loaders) == len(workflow)
+        if max_iters is not None:
+            warnings.warn(
+                'setting max_iters in run is deprecated, '
+                'please set max_iters in runner_config', DeprecationWarning)
+            self._max_iters = max_iters

-        self._max_iters = max_iters
+        assert self._max_iters is not None, (
+            'max_iters must be specified during instantiation')

         work_dir = self.work_dir if self.work_dir is not None else 'NONE'
         self.logger.info('Start running, host: %s, work_dir: %s',
                          get_host_info(), work_dir)
-        self.logger.info('workflow: %s, max: %d iters', workflow, max_iters)
+        self.logger.info('workflow: %s, max: %d iters', workflow,
+                         self._max_iters)
         self.call_hook('before_run')

         iter_loaders = [IterLoader(x) for x in data_loaders]
         self.call_hook('before_epoch')

-        while self.iter < max_iters:
+        while self.iter < self._max_iters:
             for i, flow in enumerate(workflow):
                 self._inner_iter = 0
                 mode, iters = flow
@@ -116,7 +125,7 @@ class IterBasedRunner(BaseRunner):
                             format(mode))
                 iter_runner = getattr(self, mode)
                 for _ in range(iters):
-                    if mode == 'train' and self.iter >= max_iters:
+                    if mode == 'train' and self.iter >= self._max_iters:
                         break
                     iter_runner(iter_loaders[i], **kwargs)
...
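The same pattern applies to iteration-based training, where the loop now reads self._max_iters instead of the run() argument. A minimal registry-driven sketch; the toy model and loader are placeholders, not part of the patch:

import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import build_runner


class ToyModel(nn.Module):  # placeholder, not part of the patch

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def train_step(self, x, optimizer, **kwargs):
        return dict(loss=self.linear(x).sum())


runner = build_runner(
    dict(type='IterBasedRunner', max_iters=10),  # budget set here, not in run()
    default_args=dict(model=ToyModel(), logger=logging.getLogger()))
# run() wraps each DataLoader in an IterLoader and stops at self._max_iters.
runner.run([DataLoader(torch.ones((5, 2)))], [('train', 10)])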
@@ -17,9 +17,9 @@ import torch.nn as nn
 from torch.nn.init import constant_
 from torch.utils.data import DataLoader

-from mmcv.runner import (CheckpointHook, EMAHook, EpochBasedRunner,
-                         IterTimerHook, MlflowLoggerHook, PaviLoggerHook,
-                         WandbLoggerHook)
+from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
+                         MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
+                         build_runner)
 from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
@@ -59,7 +59,7 @@ def test_ema_hook():
     checkpointhook = CheckpointHook(interval=1, by_epoch=True)
     runner.register_hook(emahook, priority='HIGHEST')
     runner.register_hook(checkpointhook)
-    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    runner.run([loader, loader], [('train', 1), ('val', 1)])
     checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
     contain_ema_buffer = False
     for name, value in checkpoint['state_dict'].items():
@@ -74,12 +74,12 @@ def test_ema_hook():
     work_dir = runner.work_dir
     resume_ema_hook = EMAHook(
         momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
-    runner = _build_demo_runner()
+    runner = _build_demo_runner(max_epochs=2)
     runner.model = demo_model
     runner.register_hook(resume_ema_hook, priority='HIGHEST')
     checkpointhook = CheckpointHook(interval=1, by_epoch=True)
     runner.register_hook(checkpointhook)
-    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)
+    runner.run([loader, loader], [('train', 1), ('val', 1)])
     checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
     contain_ema_buffer = False
     for name, value in checkpoint['state_dict'].items():
@@ -101,7 +101,7 @@ def test_pavi_hook():
     runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
     hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
     runner.register_hook(hook)
-    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    runner.run([loader, loader], [('train', 1), ('val', 1)])
     shutil.rmtree(runner.work_dir)

     assert hasattr(hook, 'writer')
@@ -119,7 +119,7 @@ def test_sync_buffers_hook():
     loader = DataLoader(torch.ones((5, 2)))
     runner = _build_demo_runner()
     runner.register_hook_from_cfg(dict(type='SyncBuffersHook'))
-    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    runner.run([loader, loader], [('train', 1), ('val', 1)])
     shutil.rmtree(runner.work_dir)
@@ -151,7 +151,7 @@ def test_momentum_runner_hook():
     # add pavi hook
     hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
     runner.register_hook(hook)
-    runner.run([loader], [('train', 1)], 1)
+    runner.run([loader], [('train', 1)])
     shutil.rmtree(runner.work_dir)

     # TODO: use a more elegant way to check values
@@ -202,7 +202,7 @@ def test_cosine_runner_hook():
     # add pavi hook
     hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
     runner.register_hook(hook)
-    runner.run([loader], [('train', 1)], 1)
+    runner.run([loader], [('train', 1)])
     shutil.rmtree(runner.work_dir)

     # TODO: use a more elegant way to check values
@@ -261,7 +261,7 @@ def test_cosine_restart_lr_update_hook():
     # add pavi hook
     hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
     runner.register_hook(hook)
-    runner.run([loader], [('train', 1)], 1)
+    runner.run([loader], [('train', 1)])
     shutil.rmtree(runner.work_dir)

     sys.modules['pavi'] = MagicMock()
@@ -280,7 +280,7 @@ def test_cosine_restart_lr_update_hook():
     # add pavi hook
     hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
     runner.register_hook(hook)
-    runner.run([loader], [('train', 1)], 1)
+    runner.run([loader], [('train', 1)])
     shutil.rmtree(runner.work_dir)

     # TODO: use a more elegant way to check values
@@ -312,7 +312,7 @@ def test_mlflow_hook(log_model):
     hook = MlflowLoggerHook(exp_name='test', log_model=log_model)
     runner.register_hook(hook)
-    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    runner.run([loader, loader], [('train', 1), ('val', 1)])
     shutil.rmtree(runner.work_dir)

     hook.mlflow.set_experiment.assert_called_with('test')
@@ -335,7 +335,7 @@ def test_wandb_hook():
     loader = DataLoader(torch.ones((5, 2)))
     runner.register_hook(hook)
-    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    runner.run([loader, loader], [('train', 1), ('val', 1)])
     shutil.rmtree(runner.work_dir)

     hook.wandb.init.assert_called_with()
@@ -347,7 +347,9 @@ def test_wandb_hook():
     hook.wandb.join.assert_called_with()


-def _build_demo_runner():
+def _build_demo_runner(runner_type='EpochBasedRunner',
+                       max_epochs=1,
+                       max_iters=None):

     class Model(nn.Module):
@@ -374,11 +376,15 @@ def _build_demo_runner():
     ])

     tmp_dir = tempfile.mkdtemp()
-    runner = EpochBasedRunner(
-        model=model,
-        work_dir=tmp_dir,
-        optimizer=optimizer,
-        logger=logging.getLogger())
+    runner = build_runner(
+        dict(type=runner_type),
+        default_args=dict(
+            model=model,
+            work_dir=tmp_dir,
+            optimizer=optimizer,
+            logger=logging.getLogger(),
+            max_epochs=max_epochs,
+            max_iters=max_iters))
     runner.register_checkpoint_hook(dict(interval=1))
     runner.register_logger_hooks(log_config)
     return runner
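With the helper parametrized this way, hook tests can target either runner type. A usage sketch of the helper defined just above (not part of the diff; the loader is the same toy DataLoader the tests use). Note that max_epochs must be cleared explicitly when requesting an IterBasedRunner, since its default of 1 would otherwise collide with max_iters in BaseRunner.__init__:

loader = DataLoader(torch.ones((5, 2)))

# Epoch-based demo runner (the default):
runner = _build_demo_runner()
runner.run([loader], [('train', 1)])

# Iteration-based demo runner; clear max_epochs so only one budget is set.
runner = _build_demo_runner(
    runner_type='IterBasedRunner', max_epochs=None, max_iters=5)
runner.run([loader], [('train', 5)])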
@@ -11,7 +11,8 @@ import torch
 import torch.nn as nn

 from mmcv.parallel import MMDataParallel
-from mmcv.runner import EpochBasedRunner
+from mmcv.runner import (RUNNERS, EpochBasedRunner, IterBasedRunner,
+                         build_runner)


 class OldStyleModel(nn.Module):
@@ -30,7 +31,29 @@ class Model(OldStyleModel):
     pass


-def test_epoch_based_runner():
+def test_build_runner():
+    temp_root = tempfile.gettempdir()
+    dir_name = ''.join(
+        [random.choice(string.ascii_letters) for _ in range(10)])
+
+    default_args = dict(
+        model=Model(),
+        work_dir=osp.join(temp_root, dir_name),
+        logger=logging.getLogger())
+
+    cfg = dict(type='EpochBasedRunner', max_epochs=1)
+    runner = build_runner(cfg, default_args=default_args)
+    assert runner._max_epochs == 1
+
+    cfg = dict(type='IterBasedRunner', max_iters=1)
+    runner = build_runner(cfg, default_args=default_args)
+    assert runner._max_iters == 1
+
+    with pytest.raises(ValueError, match='Only one of'):
+        cfg = dict(type='IterBasedRunner', max_epochs=1, max_iters=1)
+        runner = build_runner(cfg, default_args=default_args)
+
+
+@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
+def test_epoch_based_runner(runner_class):

     with pytest.warns(UserWarning):
         # batch_processor is deprecated
@@ -39,48 +62,46 @@ def test_epoch_based_runner():
         def batch_processor():
             pass

-        _ = EpochBasedRunner(
-            model, batch_processor, logger=logging.getLogger())
+        _ = runner_class(model, batch_processor, logger=logging.getLogger())

     with pytest.raises(TypeError):
         # batch_processor must be callable
         model = OldStyleModel()
-        _ = EpochBasedRunner(
-            model, batch_processor=0, logger=logging.getLogger())
+        _ = runner_class(model, batch_processor=0, logger=logging.getLogger())

     with pytest.raises(TypeError):
         # optimizer must be a optimizer or a dict of optimizers
         model = Model()
         optimizer = 'NotAOptimizer'
-        _ = EpochBasedRunner(
+        _ = runner_class(
             model, optimizer=optimizer, logger=logging.getLogger())

     with pytest.raises(TypeError):
         # optimizer must be a optimizer or a dict of optimizers
         model = Model()
         optimizers = dict(optim1=torch.optim.Adam(), optim2='NotAOptimizer')
-        _ = EpochBasedRunner(
+        _ = runner_class(
             model, optimizer=optimizers, logger=logging.getLogger())

     with pytest.raises(TypeError):
         # logger must be a logging.Logger
         model = Model()
-        _ = EpochBasedRunner(model, logger=None)
+        _ = runner_class(model, logger=None)

     with pytest.raises(TypeError):
         # meta must be a dict or None
         model = Model()
-        _ = EpochBasedRunner(model, logger=logging.getLogger(), meta=['list'])
+        _ = runner_class(model, logger=logging.getLogger(), meta=['list'])

     with pytest.raises(AssertionError):
         # model must implement the method train_step()
         model = OldStyleModel()
-        _ = EpochBasedRunner(model, logger=logging.getLogger())
+        _ = runner_class(model, logger=logging.getLogger())

     with pytest.raises(TypeError):
         # work_dir must be a str or None
         model = Model()
-        _ = EpochBasedRunner(model, work_dir=1, logger=logging.getLogger())
+        _ = runner_class(model, work_dir=1, logger=logging.getLogger())

     with pytest.raises(RuntimeError):
         # batch_processor and train_step() cannot be both set
@@ -89,8 +110,7 @@ def test_epoch_based_runner():
             pass

         model = Model()
-        _ = EpochBasedRunner(
-            model, batch_processor, logger=logging.getLogger())
+        _ = runner_class(model, batch_processor, logger=logging.getLogger())

     # test work_dir
     model = Model()
@@ -98,23 +118,24 @@ def test_epoch_based_runner():
     dir_name = ''.join(
         [random.choice(string.ascii_letters) for _ in range(10)])
     work_dir = osp.join(temp_root, dir_name)
-    _ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
+    _ = runner_class(model, work_dir=work_dir, logger=logging.getLogger())
     assert osp.isdir(work_dir)
-    _ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
+    _ = runner_class(model, work_dir=work_dir, logger=logging.getLogger())
     assert osp.isdir(work_dir)
     os.removedirs(work_dir)


-def test_runner_with_parallel():
+@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
+def test_runner_with_parallel(runner_class):

     def batch_processor():
         pass

     model = MMDataParallel(OldStyleModel())
-    _ = EpochBasedRunner(model, batch_processor, logger=logging.getLogger())
+    _ = runner_class(model, batch_processor, logger=logging.getLogger())

     model = MMDataParallel(Model())
-    _ = EpochBasedRunner(model, logger=logging.getLogger())
+    _ = runner_class(model, logger=logging.getLogger())

     with pytest.raises(RuntimeError):
         # batch_processor and train_step() cannot be both set
@@ -123,13 +144,13 @@ def test_runner_with_parallel():
             pass

         model = MMDataParallel(Model())
-        _ = EpochBasedRunner(
-            model, batch_processor, logger=logging.getLogger())
+        _ = runner_class(model, batch_processor, logger=logging.getLogger())


-def test_save_checkpoint():
+@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
+def test_save_checkpoint(runner_class):
     model = Model()
-    runner = EpochBasedRunner(model=model, logger=logging.getLogger())
+    runner = runner_class(model=model, logger=logging.getLogger())

     with pytest.raises(TypeError):
         # meta should be None or dict
@@ -139,18 +160,23 @@ def test_save_checkpoint():
         runner.save_checkpoint(root)

         latest_path = osp.join(root, 'latest.pth')
-        epoch1_path = osp.join(root, 'epoch_1.pth')

         assert osp.exists(latest_path)
-        assert osp.exists(epoch1_path)
-        assert osp.realpath(latest_path) == osp.realpath(epoch1_path)
+
+        if isinstance(runner, EpochBasedRunner):
+            first_ckp_path = osp.join(root, 'epoch_1.pth')
+        elif isinstance(runner, IterBasedRunner):
+            first_ckp_path = osp.join(root, 'iter_1.pth')
+
+        assert osp.exists(first_ckp_path)
+        assert osp.realpath(latest_path) == osp.realpath(first_ckp_path)

         torch.load(latest_path)


-def test_build_lr_momentum_hook():
+@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
+def test_build_lr_momentum_hook(runner_class):
     model = Model()
-    runner = EpochBasedRunner(model=model, logger=logging.getLogger())
+    runner = runner_class(model=model, logger=logging.getLogger())

     # test policy that is already title
     lr_config = dict(
...