Unverified commit 35ba1528 authored by Kai Chen, committed by GitHub

Add a BaseRunner and rename Runner to EpochBasedRunner (#290)

* add a BaseRunner and rename Runner to EpochBasedRunner

* fix the train/val step

* bug fix

* update unit tests

* fix unit tests

* raise an error if both batch_processor and train_step are set

* add a unit test
parent 6f21d8b5
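The change in a nutshell, as a minimal usage sketch (illustrative only, not part of the diff below; `ToyModel` and the one-batch loader are made up): with `EpochBasedRunner`, the model itself supplies `train_step()`/`val_step()`, while the old `batch_processor` argument is still accepted but deprecated.

# Hypothetical usage sketch for the renamed runner; the toy model and data
# below are for illustration only and are not part of this commit.
import logging
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import EpochBasedRunner


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

    # train_step()/val_step() replace the deprecated batch_processor; both
    # must return a dict (optionally carrying 'log_vars' and 'num_samples').
    def train_step(self, data_batch, optimizer, **kwargs):
        return dict(loss=self(data_batch).mean())

    def val_step(self, data_batch, optimizer, **kwargs):
        return dict(loss=self(data_batch).mean())


model = ToyModel()
runner = EpochBasedRunner(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.02),
    work_dir=tempfile.mkdtemp(),
    logger=logging.getLogger())
runner.run([DataLoader(torch.ones((5, 2)))], [('train', 1)], max_epochs=1)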
# Copyright (c) Open-MMLab. All rights reserved.
from .base_runner import BaseRunner
from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
save_checkpoint, weights_to_cpu)
from .dist_utils import get_dist_info, init_dist, master_only
from .epoch_based_runner import EpochBasedRunner, Runner
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
MlflowLoggerHook, OptimizerHook, PaviLoggerHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict
__all__ = [
'BaseRunner', 'Runner', 'EpochBasedRunner', 'LogBuffer', 'HOOKS', 'Hook',
'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'PaviLoggerHook',
'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook',
'MlflowLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
......
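Both names remain importable from `mmcv.runner` after this change; `Runner` is kept only for backward compatibility and is now a deprecated subclass of `EpochBasedRunner` (see epoch_based_runner.py below). A minimal sketch, assuming mmcv is installed:

from mmcv.runner import EpochBasedRunner, Runner

# Runner now just subclasses EpochBasedRunner; constructing it emits a
# deprecation message via warnings.warn (a plain UserWarning).
assert issubclass(Runner, EpochBasedRunner)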
# Copyright (c) Open-MMLab. All rights reserved.
import logging
import os.path as osp
import time
import warnings
from abc import ABCMeta, abstractmethod
import torch
import mmcv
from .checkpoint import load_checkpoint, save_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook, IterTimerHook
from .log_buffer import LogBuffer
from .priority import get_priority
from .utils import get_host_info, get_time_str, obj_from_dict
class BaseRunner(metaclass=ABCMeta):
"""The base class of Runner, a training helper for PyTorch.
All subclasses should implement the following APIs:
- ``run()``
- ``train()``
- ``val()``
- ``save_checkpoint()``
Args:
model (:obj:`torch.nn.Module`): The model to be run.
......@@ -25,29 +33,38 @@ class Runner(object):
optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
runner will construct an optimizer according to it.
work_dir (str, optional): The working directory to save checkpoints
and logs. Defaults to None.
logger (:obj:`logging.Logger`): Logger used during training.
Defaults to None.
meta (dict | None): A dict that records some important information such as
environment info and seed, which will be logged in the logger hook.
Defaults to None.
"""
def __init__(self,
model,
batch_processor=None,
optimizer=None,
work_dir=None,
logger=None,
meta=None):
if batch_processor is not None:
if not callable(batch_processor):
raise TypeError('batch_processor must be callable, '
f'but got {type(batch_processor)}')
warnings.warn('batch_processor is deprecated, please implement '
'train_step() and val_step() in the model instead.')
# raise an error if `batch_processor` is not None and
# `model.train_step()` exists.
if hasattr(model, 'train_step') or hasattr(model, 'val_step'):
raise RuntimeError(
'batch_processor and model.train_step()/model.val_step() '
'cannot be both available.')
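# In short: pass either a model that implements train_step()/val_step(), or a
# legacy batch_processor callable, but never both at once.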
else:
assert hasattr(model, 'train_step')
self.model = model
self.batch_processor = batch_processor
self.optimizer = optimizer
# create work_dir
if mmcv.is_str(work_dir):
......@@ -64,18 +81,15 @@ class Runner(object):
else:
self._model_name = self.model.__class__.__name__
assert logging is not None
self.logger = logger
if meta is not None:
assert isinstance(meta, dict), '"meta" must be a dict or None'
self.meta = meta
self._rank, self._world_size = get_dist_info()
self.timestamp = get_time_str()
self.mode = None
self._hooks = []
self._epoch = 0
......@@ -83,6 +97,8 @@ class Runner(object):
self._inner_iter = 0
self._max_epochs = 0
self._max_iters = 0
# TODO: Redesign LogBuffer, it is not flexible and elegant enough
self.log_buffer = LogBuffer()
@property
def model_name(self):
......@@ -130,62 +146,26 @@ class Runner(object):
"""int: Maximum training iterations."""
return self._max_iters
def init_optimizer(self, optimizer):
"""Init the optimizer.
Args:
optimizer (dict or :obj:`~torch.optim.Optimizer`): Either an
optimizer object or a dict used for constructing the optimizer.
Returns:
:obj:`~torch.optim.Optimizer`: An optimizer object.
Examples:
>>> optimizer = dict(type='SGD', lr=0.01, momentum=0.9)
>>> type(runner.init_optimizer(optimizer))
<class 'torch.optim.sgd.SGD'>
"""
if isinstance(optimizer, dict):
optimizer = obj_from_dict(optimizer, torch.optim,
dict(params=self.model.parameters()))
elif not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError(
'optimizer must be either an Optimizer object or a dict, '
f'but got {type(optimizer)}')
return optimizer
def _add_file_handler(self,
logger,
filename=None,
mode='w',
level=logging.INFO):
# TODO: move this method out of runner
file_handler = logging.FileHandler(filename, mode)
file_handler.setFormatter(
logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
file_handler.setLevel(level)
logger.addHandler(file_handler)
return logger
def init_logger(self, log_dir=None, level=logging.INFO):
"""Init the logger.
Args:
log_dir (str, optional): Log file directory. If not specified, no
log file will be used.
level (int or str): See the built-in python logging module.
Returns:
:obj:`~logging.Logger`: Python logger.
"""
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=level)
logger = logging.getLogger(__name__)
if log_dir and self.rank == 0:
filename = f'{self.timestamp}.log'
log_file = osp.join(log_dir, filename)
self._add_file_handler(logger, log_file, level=level)
return logger
@abstractmethod
def train(self):
pass
@abstractmethod
def val(self):
pass
@abstractmethod
def run(self, data_loaders, workflow, **kwargs):
pass
@abstractmethod
def save_checkpoint(self,
out_dir,
filename_tmpl,
save_optimizer=True,
meta=None,
create_symlink=True):
pass
def current_lr(self):
"""Get current learning rates.
......@@ -220,6 +200,11 @@ class Runner(object):
def register_hook(self, hook, priority='NORMAL'):
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
priority (See :class:`Priority` for details of priorities).
For hooks with the same priority, they will be triggered in the same
order as they are registered.
Args:
hook (:obj:`Hook`): The hook to be registered.
priority (int or str or :obj:`Priority`): Hook priority.
......@@ -241,6 +226,12 @@ class Runner(object):
self._hooks.insert(0, hook)
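# Illustrative example (not from this diff):
# runner.register_hook(IterTimerHook(), priority='LOW') places the hook after
# all previously registered 'NORMAL'-priority hooks.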
def call_hook(self, fn_name):
"""Call all hooks.
Args:
fn_name (str): The function name in each hook to be called, such as
"before_train_epoch".
"""
for hook in self._hooks:
getattr(hook, fn_name)(self)
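# For example, self.call_hook('after_train_iter') simply runs
# hook.after_train_iter(self) on every registered hook, in priority order.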
......@@ -249,72 +240,6 @@ class Runner(object):
return load_checkpoint(self.model, filename, map_location, strict,
self.logger)
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None,
create_symlink=True):
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
else:
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('batch_processor() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
with torch.no_grad():
outputs = self.batch_processor(
self.model, data_batch, train_mode=False, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('batch_processor() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
def resume(self,
checkpoint,
resume_optimizer=True,
......@@ -335,57 +260,6 @@ class Runner(object):
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g., [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
elif callable(mode): # custom train()
epoch_runner = mode
else:
raise TypeError('mode in workflow must be a str or '
f'callable function, not {type(mode)}')
for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs:
return
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def register_lr_hook(self, lr_config):
if isinstance(lr_config, dict):
assert 'policy' in lr_config
......@@ -404,26 +278,6 @@ class Runner(object):
hook = lr_config
self.register_hook(hook)
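# lr_config is typically a dict such as dict(policy='step', step=[8, 11])
# (illustrative values); the mandatory 'policy' key (asserted above) selects
# the LR updater hook to build, or an already-built Hook instance can be
# passed directly (the fallback branch above).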
def register_momentum_hook(self, momentum_config):
if momentum_config is None:
return
......@@ -444,6 +298,26 @@ class Runner(object):
hook = momentum_config
self.register_hook(hook)
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
optimizer_config.setdefault('type', 'OptimizerHook')
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook)
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
checkpoint_config.setdefault('type', 'CheckpointHook')
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook)
def register_logger_hooks(self, log_config):
log_interval = log_config['interval']
for info in log_config['hooks']:
......
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings
import torch
import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
This runner trains models epoch by epoch.
"""
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
if self.batch_processor is None:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
' must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
with torch.no_grad():
if self.batch_processor is None:
outputs = self.model.val_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=False, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.val_step()"'
' must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g., [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs:
return
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None,
create_symlink=True):
"""Save the checkpoint.
Args:
out_dir (str): The directory that checkpoints are saved.
filename_tmpl (str, optional): The checkpoint filename template,
which contains a placeholder for the epoch number.
Defaults to 'epoch_{}.pth'.
save_optimizer (bool, optional): Whether to save the optimizer to
the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None.
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
"""
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
else:
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner"""
def __init__(self, *args, **kwargs):
warnings.warn(
'Runner was deprecated, please use EpochBasedRunner instead')
super().__init__(*args, **kwargs)
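# Backward-compatibility sketch (illustrative): constructing the old name
# still works but warns, e.g.
# runner = Runner(model, logger=logging.getLogger())  # emits a UserWarning
# provided the model implements train_step(), as asserted in BaseRunner.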
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......
......@@ -5,7 +5,7 @@ import os.path as osp
import numpy as np
import torch
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......
......@@ -3,7 +3,7 @@ import os.path as osp
import torch
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......
# Copyright (c) Open-MMLab. All rights reserved.
from __future__ import division
from math import cos, pi
from .hook import HOOKS, Hook
......
......@@ -18,7 +18,12 @@ import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
PaviLoggerHook, WandbLoggerHook)
from mmcv.runner.hooks.lr_updater import (CosineAnealingLrUpdaterHook,
CyclicLrUpdaterHook)
from mmcv.runner.hooks.momentum_updater import (
CosineAnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook)
def test_pavi_hook():
......@@ -26,8 +31,7 @@ def test_pavi_hook():
loader = DataLoader(torch.ones((5, 2)))
runner = _build_demo_runner()
hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
shutil.rmtree(runner.work_dir)
......@@ -52,7 +56,7 @@ def test_momentum_runner_hook():
runner = _build_demo_runner()
# add momentum scheduler
hook = CyclicMomentumUpdaterHook(
by_epoch=False,
target_ratio=(0.85 / 0.95, 1),
cyclic_times=1,
......@@ -60,17 +64,16 @@ def test_momentum_runner_hook():
runner.register_hook(hook)
# add momentum LR scheduler
hook = CyclicLrUpdaterHook(
by_epoch=False,
target_ratio=(10, 1),
cyclic_times=1,
step_ratio_up=0.4)
runner.register_hook(hook)
runner.register_hook(IterTimerHook())
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1)
shutil.rmtree(runner.work_dir)
......@@ -103,23 +106,21 @@ def test_cosine_runner_hook():
runner = _build_demo_runner()
# add momentum scheduler
hook = CosineAnealingMomentumUpdaterHook(
min_momentum_ratio=0.99 / 0.95,
by_epoch=False,
warmup_iters=2,
warmup_ratio=0.9 / 0.95)
runner.register_hook(hook)
# add momentum LR scheduler
hook = CosineAnealingLrUpdaterHook(
by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
runner.register_hook(hook)
runner.register_hook(IterTimerHook())
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1)
shutil.rmtree(runner.work_dir)
......@@ -151,8 +152,7 @@ def test_mlflow_hook(log_model):
runner = _build_demo_runner()
loader = DataLoader(torch.ones((5, 2)))
hook = MlflowLoggerHook(exp_name='test', log_model=log_model)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
shutil.rmtree(runner.work_dir)
......@@ -173,7 +173,7 @@ def test_mlflow_hook(log_model):
def test_wandb_hook():
sys.modules['wandb'] = MagicMock()
runner = _build_demo_runner()
hook = WandbLoggerHook()
loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook)
......@@ -190,7 +190,24 @@ def test_wandb_hook():
def _build_demo_runner():
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)
def forward(self, x):
return self.linear(x)
def train_step(self, x, optimizer, **kwargs):
return dict(loss=self(x))
def val_step(self, x, optimizer, **kwargs):
return dict(loss=self(x))
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
log_config = dict(
......@@ -199,10 +216,9 @@ def _build_demo_runner():
])
tmp_dir = tempfile.mkdtemp()
runner = EpochBasedRunner(
model=model,
work_dir=tmp_dir,
optimizer=optimizer,
logger=logging.getLogger())
......
# Copyright (c) Open-MMLab. All rights reserved.
import logging
import os
import os.path as osp
import random
import string
import tempfile
import warnings
import pytest
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner
class OldStyleModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
class Model(OldStyleModel):
def train_step(self):
pass
def val_step(self):
pass
def test_epoch_based_runner():
with pytest.warns(UserWarning):
# batch_processor is deprecated
model = OldStyleModel()
def batch_processor():
pass
_ = EpochBasedRunner(model, batch_processor)
with pytest.raises(TypeError):
# batch_processor must be callable
model = OldStyleModel()
_ = EpochBasedRunner(model, batch_processor=0)
with pytest.raises(AssertionError):
# model must implement the method train_step()
model = OldStyleModel()
_ = EpochBasedRunner(model)
with pytest.raises(TypeError):
# work_dir must be a str or None
model = Model()
_ = EpochBasedRunner(model, work_dir=1)
with pytest.raises(RuntimeError):
# batch_processor and train_step() cannot be both set
def batch_processor():
pass
model = Model()
_ = EpochBasedRunner(model, batch_processor)
# test work_dir
model = Model()
temp_root = tempfile.gettempdir()
dir_name = ''.join(
[random.choice(string.ascii_letters) for _ in range(10)])
work_dir = osp.join(temp_root, dir_name)
_ = EpochBasedRunner(model, work_dir=work_dir)
assert osp.isdir(work_dir)
_ = EpochBasedRunner(model, work_dir=work_dir)
assert osp.isdir(work_dir)
os.removedirs(work_dir)
def test_save_checkpoint():
model = Model()
runner = EpochBasedRunner(model=model, logger=logging.getLogger())
with tempfile.TemporaryDirectory() as root:
runner.save_checkpoint(root)
......@@ -33,15 +95,8 @@ def test_save_checkpoint():
def test_build_lr_momentum_hook():
model = Model()
runner = EpochBasedRunner(model=model, logger=logging.getLogger())
# test policy that is already title
lr_config = dict(
......