"...text-generation-inference.git" did not exist on "91d72675342e34c314a0d7cc9bb9ca9d8f5aa295"
Unverified commit 35ba1528, authored by Kai Chen, committed by GitHub

Add a BaseRunner and rename Runner to EpochBasedRunner (#290)

* add a BaseRunner and rename Runner to EpochBasedRunner

* fix the train/val step

* bug fix

* update unit tests

* fix unit tests

* raise an error if both batch_processor and train_step are set

* add a unit test
parent 6f21d8b5
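The central behavioural change: instead of handing the runner a standalone `batch_processor(model, data, train_mode)` callable, the model itself now implements `train_step(data, optimizer)` and `val_step(data, optimizer)`, each returning a dict (optionally carrying `log_vars` and `num_samples` for the log buffer). A minimal before/after sketch; the toy linear model and loss computation are illustrative only, not part of this commit:

import torch.nn as nn


# Before: a free function passed as Runner(..., batch_processor=...), now deprecated
def batch_processor(model, data, train_mode, **kwargs):
    loss = model(data).mean()
    return dict(loss=loss, log_vars=dict(loss=loss.item()), num_samples=len(data))


# After: the model exposes the step logic itself
class MyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, data, optimizer, **kwargs):
        loss = self(data).mean()
        return dict(loss=loss, log_vars=dict(loss=loss.item()), num_samples=len(data))

    def val_step(self, data, optimizer, **kwargs):
        return self.train_step(data, optimizer, **kwargs)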
# Copyright (c) Open-MMLab. All rights reserved.
from .base_runner import BaseRunner
from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
                         save_checkpoint, weights_to_cpu)
from .dist_utils import get_dist_info, init_dist, master_only
from .epoch_based_runner import EpochBasedRunner, Runner
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
                    Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
                    MlflowLoggerHook, OptimizerHook, PaviLoggerHook,
                    TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
                        DefaultOptimizerConstructor, build_optimizer,
                        build_optimizer_constructor)
from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict

__all__ = [
    'BaseRunner', 'Runner', 'EpochBasedRunner', 'LogBuffer', 'HOOKS', 'Hook',
    'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
    'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'PaviLoggerHook',
    'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook',
    'MlflowLoggerHook', '_load_checkpoint', 'load_state_dict',
    'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
    'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
    'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
......
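A quick check of the resulting public API; `Runner` is kept as a deprecated alias so existing imports keep working (assumes an mmcv build that contains this commit):

from mmcv.runner import BaseRunner, EpochBasedRunner, Runner

# EpochBasedRunner is the renamed Runner; the old name is retained as a
# thin deprecated subclass for backward compatibility.
assert issubclass(EpochBasedRunner, BaseRunner)
assert issubclass(Runner, EpochBasedRunner)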
# Copyright (c) Open-MMLab. All rights reserved.
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod

import torch

import mmcv
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook, IterTimerHook
from .log_buffer import LogBuffer
from .priority import get_priority
from .utils import get_time_str
class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
@@ -25,29 +33,38 @@ class Runner(object):
        optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
            runner will construct an optimizer according to it.
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            Defaults to None.
        meta (dict | None): A dict records some important information such as
            environment info and seed, which will be logged in logger hook.
            Defaults to None.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists
            if hasattr(model, 'train_step') or hasattr(model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer

        # create work_dir
        if mmcv.is_str(work_dir):
@@ -64,18 +81,15 @@ class Runner(object):
        else:
            self._model_name = self.model.__class__.__name__

        self.logger = logger
        if meta is not None:
            assert isinstance(meta, dict), '"meta" must be a dict or None'
        self.meta = meta

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
@@ -83,6 +97,8 @@ class Runner(object):
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
    @property
    def model_name(self):
@@ -130,62 +146,26 @@ class Runner(object):
        """int: Maximum training iterations."""
        return self._max_iters
    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

-    def init_optimizer(self, optimizer):
-        """Init the optimizer.
-
-        Args:
-            optimizer (dict or :obj:`~torch.optim.Optimizer`): Either an
-                optimizer object or a dict used for constructing the optimizer.
-
-        Returns:
-            :obj:`~torch.optim.Optimizer`: An optimizer object.
-
-        Examples:
-            >>> optimizer = dict(type='SGD', lr=0.01, momentum=0.9)
-            >>> type(runner.init_optimizer(optimizer))
-            <class 'torch.optim.sgd.SGD'>
-        """
-        if isinstance(optimizer, dict):
-            optimizer = obj_from_dict(optimizer, torch.optim,
-                                      dict(params=self.model.parameters()))
-        elif not isinstance(optimizer, torch.optim.Optimizer):
-            raise TypeError(
-                'optimizer must be either an Optimizer object or a dict, '
-                f'but got {type(optimizer)}')
-        return optimizer
-
-    def _add_file_handler(self,
-                          logger,
-                          filename=None,
-                          mode='w',
-                          level=logging.INFO):
-        # TODO: move this method out of runner
-        file_handler = logging.FileHandler(filename, mode)
-        file_handler.setFormatter(
-            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
-        file_handler.setLevel(level)
-        logger.addHandler(file_handler)
-        return logger
-
-    def init_logger(self, log_dir=None, level=logging.INFO):
-        """Init the logger.
-
-        Args:
-            log_dir(str, optional): Log file directory. If not specified, no
-                log file will be used.
-            level (int or str): See the built-in python logging module.
-
-        Returns:
-            :obj:`~logging.Logger`: Python logger.
-        """
-        logging.basicConfig(
-            format='%(asctime)s - %(levelname)s - %(message)s', level=level)
-        logger = logging.getLogger(__name__)
-        if log_dir and self.rank == 0:
-            filename = f'{self.timestamp}.log'
-            log_file = osp.join(log_dir, filename)
-            self._add_file_handler(logger, log_file, level=level)
-        return logger

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass
    def current_lr(self):
        """Get current learning rates.
@@ -220,6 +200,11 @@ class Runner(object):
    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :cls:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
@@ -241,6 +226,12 @@ class Runner(object):
        self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)
@@ -249,72 +240,6 @@ class Runner(object):
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None,
create_symlink=True):
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
else:
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('batch_processor() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
with torch.no_grad():
outputs = self.batch_processor(
self.model, data_batch, train_mode=False, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('batch_processor() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
    def resume(self,
               checkpoint,
               resume_optimizer=True,
@@ -335,57 +260,6 @@ class Runner(object):
        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
elif callable(mode): # custom train()
epoch_runner = mode
else:
raise TypeError('mode in workflow must be a str or '
f'callable function, not {type(mode)}')
for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs:
return
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
    def register_lr_hook(self, lr_config):
        if isinstance(lr_config, dict):
            assert 'policy' in lr_config
@@ -404,26 +278,6 @@ class Runner(object):
            hook = lr_config
        self.register_hook(hook)
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
optimizer_config.setdefault('type', 'OptimizerHook')
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook)
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
checkpoint_config.setdefault('type', 'CheckpointHook')
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook)
    def register_momentum_hook(self, momentum_config):
        if momentum_config is None:
            return
@@ -444,6 +298,26 @@ class Runner(object):
            hook = momentum_config
        self.register_hook(hook)
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
optimizer_config.setdefault('type', 'OptimizerHook')
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook)
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
checkpoint_config.setdefault('type', 'CheckpointHook')
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook)
    def register_logger_hooks(self, log_config):
        log_interval = log_config['interval']
        for info in log_config['hooks']:
......
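The abstract interface above means a custom runner only has to provide the four methods listed in the class docstring; a minimal, illustrative subclass (not part of this commit, with bodies intentionally left as stubs):

from mmcv.runner import BaseRunner


class ToyRunner(BaseRunner):
    """Sketch of a custom runner; each stub marks where real logic goes."""

    def train(self, data_loader, **kwargs):
        pass  # iterate over data_loader, call model.train_step(), fire hooks

    def val(self, data_loader, **kwargs):
        pass  # iterate over data_loader, call model.val_step() under no_grad

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        pass  # drive train()/val() according to the (phase, epochs) workflow

    def save_checkpoint(self, out_dir, filename_tmpl, save_optimizer=True,
                        meta=None, create_symlink=True):
        pass  # persist model (and optionally optimizer) state to out_dir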
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings
import torch
import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
This runner trains models epoch by epoch.
"""
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
if self.batch_processor is None:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
' must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
with torch.no_grad():
if self.batch_processor is None:
outputs = self.model.val_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=False, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.val_step()"'
' must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g., [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs:
return
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None,
create_symlink=True):
"""Save the checkpoint.
Args:
out_dir (str): The directory that checkpoints are saved.
filename_tmpl (str, optional): The checkpoint filename template,
which contains a placeholder for the epoch number.
Defaults to 'epoch_{}.pth'.
save_optimizer (bool, optional): Whether to save the optimizer to
the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None.
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
"""
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
else:
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner"""
def __init__(self, *args, **kwargs):
warnings.warn('Runner is deprecated, please use EpochBasedRunner instead',
              DeprecationWarning)
super().__init__(*args, **kwargs)
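Putting it together, a minimal end-to-end sketch of driving EpochBasedRunner with a workflow, mirroring the unit tests further below; the toy model, loader, optimizer and temporary work_dir are illustrative, not part of this commit:

import logging
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import EpochBasedRunner


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, data, optimizer, **kwargs):
        # the returned dict is stored in runner.outputs
        return dict(loss=self(data).mean())

    def val_step(self, data, optimizer, **kwargs):
        return dict(loss=self(data).mean())


model = ToyModel()
runner = EpochBasedRunner(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.02),
    work_dir=tempfile.mkdtemp(),
    logger=logging.getLogger())

loader = DataLoader(torch.ones((5, 2)))
# one training epoch followed by one validation epoch, with max_epochs=1
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)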
# Copyright (c) Open-MMLab. All rights reserved.
import numbers

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......
@@ -5,7 +5,7 @@ import os.path as osp
import numpy as np
import torch

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......
@@ -3,7 +3,7 @@ import os.path as osp
import torch

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......
# Copyright (c) Open-MMLab. All rights reserved.
import numbers

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......
# Copyright (c) Open-MMLab. All rights reserved.
from math import cos, pi

from .hook import HOOKS, Hook
......
@@ -18,7 +18,12 @@ import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
                         PaviLoggerHook, WandbLoggerHook)
from mmcv.runner.hooks.lr_updater import (CosineAnealingLrUpdaterHook,
                                          CyclicLrUpdaterHook)
from mmcv.runner.hooks.momentum_updater import (
    CosineAnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook)
def test_pavi_hook():
@@ -26,8 +31,7 @@ def test_pavi_hook():
    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()

    hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    shutil.rmtree(runner.work_dir)
@@ -52,7 +56,7 @@ def test_momentum_runner_hook():
    runner = _build_demo_runner()

    # add momentum scheduler
    hook = CyclicMomentumUpdaterHook(
        by_epoch=False,
        target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1,
@@ -60,17 +64,16 @@ def test_momentum_runner_hook():
    runner.register_hook(hook)

    # add momentum LR scheduler
    hook = CyclicLrUpdaterHook(
        by_epoch=False,
        target_ratio=(10, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)], 1)
    shutil.rmtree(runner.work_dir)
@@ -103,23 +106,21 @@ def test_cosine_runner_hook():
    runner = _build_demo_runner()

    # add momentum scheduler
    hook = CosineAnealingMomentumUpdaterHook(
        min_momentum_ratio=0.99 / 0.95,
        by_epoch=False,
        warmup_iters=2,
        warmup_ratio=0.9 / 0.95)
    runner.register_hook(hook)

    # add momentum LR scheduler
    hook = CosineAnealingLrUpdaterHook(
        by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)], 1)
    shutil.rmtree(runner.work_dir)
@@ -151,8 +152,7 @@ def test_mlflow_hook(log_model):
    runner = _build_demo_runner()
    loader = DataLoader(torch.ones((5, 2)))

    hook = MlflowLoggerHook(exp_name='test', log_model=log_model)
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    shutil.rmtree(runner.work_dir)
@@ -173,7 +173,7 @@ def test_wandb_hook():
def test_wandb_hook():
    sys.modules['wandb'] = MagicMock()
    runner = _build_demo_runner()
    hook = WandbLoggerHook()
    loader = DataLoader(torch.ones((5, 2)))
    runner.register_hook(hook)
@@ -190,7 +190,24 @@ def test_wandb_hook():
def _build_demo_runner():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    model = Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

    log_config = dict(
@@ -199,10 +216,9 @@ def _build_demo_runner():
        ])

    tmp_dir = tempfile.mkdtemp()
    runner = EpochBasedRunner(
        model=model,
        work_dir=tmp_dir,
        optimizer=optimizer,
        logger=logging.getLogger())
......
# Copyright (c) Open-MMLab. All rights reserved.
import logging
import os
import os.path as osp
import random
import string
import tempfile

import pytest
import torch
import torch.nn as nn

from mmcv.runner import EpochBasedRunner


class OldStyleModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)


class Model(OldStyleModel):

    def train_step(self):
        pass

    def val_step(self):
        pass


def test_epoch_based_runner():

    with pytest.warns(UserWarning):
        # batch_processor is deprecated
        model = OldStyleModel()

        def batch_processor():
            pass

        _ = EpochBasedRunner(model, batch_processor)

    with pytest.raises(TypeError):
        # batch_processor must be callable
        model = OldStyleModel()
        _ = EpochBasedRunner(model, batch_processor=0)

    with pytest.raises(AssertionError):
        # model must implement the method train_step()
        model = OldStyleModel()
        _ = EpochBasedRunner(model)

    with pytest.raises(TypeError):
        # work_dir must be a str or None
        model = Model()
        _ = EpochBasedRunner(model, work_dir=1)

    with pytest.raises(RuntimeError):
        # batch_processor and train_step() cannot be both set

        def batch_processor():
            pass

        model = Model()
        _ = EpochBasedRunner(model, batch_processor)

    # test work_dir
    model = Model()
    temp_root = tempfile.gettempdir()
    dir_name = ''.join(
        [random.choice(string.ascii_letters) for _ in range(10)])
    work_dir = osp.join(temp_root, dir_name)
    _ = EpochBasedRunner(model, work_dir=work_dir)
    assert osp.isdir(work_dir)
    _ = EpochBasedRunner(model, work_dir=work_dir)
    assert osp.isdir(work_dir)
    os.removedirs(work_dir)


def test_save_checkpoint():
    model = Model()
    runner = EpochBasedRunner(model=model, logger=logging.getLogger())

    with tempfile.TemporaryDirectory() as root:
        runner.save_checkpoint(root)
@@ -33,15 +95,8 @@ def test_save_checkpoint():
def test_build_lr_momentum_hook():
    model = Model()
    runner = EpochBasedRunner(model=model, logger=logging.getLogger())

    # test policy that is already title
    lr_config = dict(
......