Unverified commit ba059611 authored by Wenwei Zhang, committed by GitHub

Momentum scheduler (#167)

* track progress of iter&enum

* restore

* add momentum scheduler

* fix small bug

* cyclic scheduler"

* fix bug

* fix second phase's bug

* reformat

* feature (cosine lr): use relative ratio for more flexible scheduler

* Fix (runner): fix bugs in runner

* Refactor (hook): refactor cosine/cyclic LR/momentum hook with unittest

* Clean unnecessary files and reformat

* Fix memory key error when GPU is not available

* Resolve comments

* Do not print momentum in text log

* Change hook register order

* Refactor max_iter

* Fix max_iter bugs in runner

* Enforce target_ratio to be either tuple or float
parent 8ac858b1
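For orientation, a minimal sketch of how the new schedulers added here might be switched on from a training config. It assumes the usual mmcv convention that a dict's policy name is title-cased and suffixed with the hook class name (see register_momentum_hooks in the runner diff below); the keys shown are the hook arguments from this commit, but the snippet itself is not part of the change.

# Hypothetical config snippet; values are the hook defaults added in this commit.
lr_config = dict(
    policy='cyclic',  # assumed to resolve to CyclicLrUpdaterHook
    by_epoch=False,
    target_ratio=(10, 1e-4),  # (highest, lowest) LR relative to the initial LR
    cyclic_times=1,
    step_ratio_up=0.4)
momentum_config = dict(
    policy='cyclic',  # resolves to CyclicMomentumUpdaterHook
    by_epoch=False,
    target_ratio=(0.85 / 0.95, 1),  # (lowest, highest) momentum relative to the initial value
    cyclic_times=1,
    step_ratio_up=0.4)
# runner.register_training_hooks(lr_config, optimizer_config, checkpoint_config,
#                                log_config, momentum_config=momentum_config)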
exclude: ^tests/data/
repos:
  - repo: https://gitlab.com/pycqa/flake8.git
    rev: 3.7.9
    hooks:
      - id: flake8
  - repo: https://github.com/asottile/seed-isort-config
    rev: v2.1.0
    hooks:
      - id: seed-isort-config
  - repo: https://github.com/timothycrosley/isort
    rev: 4.3.21
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.29.0
    hooks:
      - id: yapf
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.5.0
    hooks:
      - id: trailing-whitespace
      - id: check-yaml
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: double-quote-string-fixer
      - id: fix-encoding-pragma
        args: ["--remove"]
      - id: mixed-line-ending
        args: ["--fix=lf"]
@@ -7,6 +7,7 @@ from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
                     TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook
from .optimizer import OptimizerHook
from .sampler_seed import DistSamplerSeedHook
@@ -14,5 +15,5 @@ __all__ = [
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook',
    'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 'TextLoggerHook',
    'TensorboardLoggerHook', 'WandbLoggerHook', 'MomentumUpdaterHook'
]
@@ -70,6 +70,8 @@ class MlflowLoggerHook(LoggerHook):
            tag = '{}/{}'.format(var, runner.mode)
            if isinstance(val, numbers.Number):
                metrics[tag] = val
        metrics['learning_rate'] = runner.current_lr()[0]
        metrics['momentum'] = runner.current_momentum()[0]
        self.mlflow.log_metrics(metrics, step=runner.iter)

    @master_only
...
@@ -71,6 +71,8 @@ class PaviLoggerHook(LoggerHook):
        for tag, val in runner.log_buffer.output.items():
            if tag not in ['time', 'data_time'] and is_scalar(val):
                tags[tag] = val
        tags['learning_rate'] = runner.current_lr()[0]
        tags['momentum'] = runner.current_momentum()[0]
        if tags:
            self.writer.add_scalars(runner.mode, tags, runner.iter)
...
@@ -52,6 +52,10 @@ class TensorboardLoggerHook(LoggerHook):
            else:
                self.writer.add_scalar(tag, runner.log_buffer.output[var],
                                       runner.iter)
        self.writer.add_scalar('learning_rate',
                               runner.current_lr()[0], runner.iter)
        self.writer.add_scalar('momentum',
                               runner.current_momentum()[0], runner.iter)

    @master_only
    def after_run(self, runner):
...
@@ -49,7 +49,9 @@ class TextLoggerHook(LoggerHook):
            log_str += 'eta: {}, '.format(eta_str)
            log_str += ('time: {:.3f}, data_time: {:.3f}, '.format(
                log_dict['time'], log_dict['data_time']))
            # statistic memory
            if torch.cuda.is_available():
                log_str += 'memory: {}, '.format(log_dict['memory'])
        else:
            log_str = 'Epoch({}) [{}][{}]\t'.format(log_dict['mode'],
                                                    log_dict['epoch'] - 1,
@@ -100,6 +102,7 @@ class TextLoggerHook(LoggerHook):
        if mode == 'train':
            log_dict['time'] = runner.log_buffer.output['time']
            log_dict['data_time'] = runner.log_buffer.output['data_time']
            # statistic memory
            if torch.cuda.is_available():
                log_dict['memory'] = self._get_max_memory(runner)
...
@@ -45,6 +45,8 @@ class WandbLoggerHook(LoggerHook):
            tag = '{}/{}'.format(var, runner.mode)
            if isinstance(val, numbers.Number):
                metrics[tag] = val
        metrics['learning_rate'] = runner.current_lr()[0]
        metrics['momentum'] = runner.current_momentum()[0]
        if metrics:
            self.wandb.log(metrics, step=runner.iter)
...
@@ -199,11 +199,13 @@ class InvLrUpdaterHook(LrUpdaterHook):

@HOOKS.register_module
class CosineAnealingLrUpdaterHook(LrUpdaterHook):

    def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
        assert (min_lr is None) ^ (min_lr_ratio is None)
        self.min_lr = min_lr
        self.min_lr_ratio = min_lr_ratio
        super(CosineAnealingLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        if self.by_epoch:
@@ -212,5 +214,88 @@ class CosineLrUpdaterHook(LrUpdaterHook):
        else:
            progress = runner.iter
            max_progress = runner.max_iters
        if self.min_lr_ratio is not None:
            target_lr = base_lr * self.min_lr_ratio
        else:
            target_lr = self.min_lr
        return annealing_cos(base_lr, target_lr, progress / max_progress)


@HOOKS.register_module
class CyclicLrUpdaterHook(LrUpdaterHook):
    """Cyclic LR Scheduler.

    Implement the cyclical learning rate policy (CLR) described in
    https://arxiv.org/pdf/1506.01186.pdf

    Different from the original paper, we use cosine annealing rather than
    the triangular policy inside a cycle. This improves the performance in
    the 3D detection area.

    Attributes:
        target_ratio (tuple[float]): Relative ratio of the highest LR and the
            lowest LR to the initial LR.
        cyclic_times (int): Number of cycles during training.
        step_ratio_up (float): The ratio of the increasing process of LR in
            the total cycle.
        by_epoch (bool): Whether to update LR by epoch.
    """

    def __init__(self,
                 by_epoch=False,
                 target_ratio=(10, 1e-4),
                 cyclic_times=1,
                 step_ratio_up=0.4,
                 **kwargs):
        if isinstance(target_ratio, float):
            target_ratio = (target_ratio, target_ratio / 1e5)
        elif isinstance(target_ratio, tuple):
            target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
                if len(target_ratio) == 1 else target_ratio
        else:
            raise ValueError('target_ratio should be either float '
                             'or tuple, got {}'.format(type(target_ratio)))

        assert len(target_ratio) == 2, \
            '"target_ratio" must be list or tuple of two floats'
        assert 0 <= step_ratio_up < 1.0, \
            '"step_ratio_up" must be in range [0,1)'

        self.target_ratio = target_ratio
        self.cyclic_times = cyclic_times
        self.step_ratio_up = step_ratio_up
        self.lr_phases = []  # init lr_phases
        assert not by_epoch, \
            'currently only support "by_epoch" = False'
        super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)

    def before_run(self, runner):
        super(CyclicLrUpdaterHook, self).before_run(runner)
        # initiate lr_phases
        # total lr_phases are separated as up and down
        max_iter_per_phase = runner.max_iters // self.cyclic_times
        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
        self.lr_phases.append(
            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
        self.lr_phases.append([
            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
            self.target_ratio[0], self.target_ratio[1]
        ])

    def get_lr(self, runner, base_lr):
        curr_iter = runner.iter
        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
             end_ratio) in self.lr_phases:
            curr_iter %= max_iter_per_phase
            if start_iter <= curr_iter < end_iter:
                progress = curr_iter - start_iter
                return annealing_cos(base_lr * start_ratio,
                                     base_lr * end_ratio,
                                     progress / (end_iter - start_iter))


def annealing_cos(start, end, factor):
    """Cosine anneal from `start` to `end` as `factor` goes from 0.0 to 1.0."""
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * (start - end) * cos_out
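
# Illustrative sanity check (not part of the commit): annealing_cos
# interpolates from `start` to `end` along half a cosine period, so factors
# 0, 0.5 and 1 give the start value, the midpoint and the end value.
from mmcv.runner.hooks.lr_updater import annealing_cos

assert abs(annealing_cos(0.02, 0.0, 0.0) - 0.02) < 1e-12  # phase start
assert abs(annealing_cos(0.02, 0.0, 0.5) - 0.01) < 1e-12  # halfway
assert abs(annealing_cos(0.02, 0.0, 1.0) - 0.0) < 1e-12   # phase end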

# New file: mmcv/runner/hooks/momentum_updater.py
from .hook import HOOKS, Hook
from .lr_updater import annealing_cos


class MomentumUpdaterHook(Hook):

    def __init__(self,
                 by_epoch=True,
                 warmup=None,
                 warmup_iters=0,
                 warmup_ratio=0.9,
                 **kwargs):
        # validate the "warmup" argument
        if warmup is not None:
            if warmup not in ['constant', 'linear', 'exp']:
                raise ValueError(
                    '"{}" is not a supported type for warming up, valid types'
                    ' are "constant", "linear" and "exp"'.format(warmup))
        if warmup is not None:
            assert warmup_iters > 0, \
                '"warmup_iters" must be a positive integer'
            assert 0 < warmup_ratio <= 1.0, \
                '"warmup_ratio" must be in range (0,1]'

        self.by_epoch = by_epoch
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio

        self.base_momentum = []  # initial momentum for all param groups
        self.regular_momentum = []  # expected momentum if no warming up is performed

    def _set_momentum(self, runner, momentum_groups):
        for param_group, mom in zip(runner.optimizer.param_groups,
                                    momentum_groups):
            if 'momentum' in param_group.keys():
                param_group['momentum'] = mom
            elif 'betas' in param_group.keys():
                param_group['betas'] = (mom, param_group['betas'][1])

    def get_momentum(self, runner, base_momentum):
        raise NotImplementedError

    def get_regular_momentum(self, runner):
        return [
            self.get_momentum(runner, _base_momentum)
            for _base_momentum in self.base_momentum
        ]

    def get_warmup_momentum(self, cur_iters):
        if self.warmup == 'constant':
            warmup_momentum = [
                _momentum / self.warmup_ratio
                for _momentum in self.regular_momentum
            ]
        elif self.warmup == 'linear':
            k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
            warmup_momentum = [
                _momentum / (1 - k) for _momentum in self.regular_momentum
            ]
        elif self.warmup == 'exp':
            k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
            warmup_momentum = [
                _momentum / k for _momentum in self.regular_momentum
            ]
        return warmup_momentum

    def before_run(self, runner):
        # NOTE: when resuming from a checkpoint,
        # if 'initial_momentum' is not saved,
        # it will be set according to the optimizer params
        for group in runner.optimizer.param_groups:
            if 'momentum' in group.keys():
                group.setdefault('initial_momentum', group['momentum'])
            else:
                group.setdefault('initial_momentum', group['betas'][0])
        self.base_momentum = [
            group['initial_momentum']
            for group in runner.optimizer.param_groups
        ]

    def before_train_epoch(self, runner):
        if not self.by_epoch:
            return
        self.regular_momentum = self.get_regular_momentum(runner)
        self._set_momentum(runner, self.regular_momentum)

    def before_train_iter(self, runner):
        cur_iter = runner.iter
        if not self.by_epoch:
            self.regular_momentum = self.get_regular_momentum(runner)
            if self.warmup is None or cur_iter >= self.warmup_iters:
                self._set_momentum(runner, self.regular_momentum)
            else:
                warmup_momentum = self.get_warmup_momentum(cur_iter)
                self._set_momentum(runner, warmup_momentum)
        elif self.by_epoch:
            if self.warmup is None or cur_iter > self.warmup_iters:
                return
            elif cur_iter == self.warmup_iters:
                self._set_momentum(runner, self.regular_momentum)
            else:
                warmup_momentum = self.get_warmup_momentum(cur_iter)
                self._set_momentum(runner, warmup_momentum)

@HOOKS.register_module
class CosineAnealingMomentumUpdaterHook(MomentumUpdaterHook):

    def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
        assert (min_momentum is None) ^ (min_momentum_ratio is None)
        self.min_momentum = min_momentum
        self.min_momentum_ratio = min_momentum_ratio
        super(CosineAnealingMomentumUpdaterHook, self).__init__(**kwargs)

    def get_momentum(self, runner, base_momentum):
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter
            max_progress = runner.max_iters
        if self.min_momentum_ratio is not None:
            target_momentum = base_momentum * self.min_momentum_ratio
        else:
            target_momentum = self.min_momentum
        return annealing_cos(base_momentum, target_momentum,
                             progress / max_progress)


@HOOKS.register_module
class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
    """Cyclic momentum Scheduler.

    Implement the cyclical momentum scheduler policy described in
    https://arxiv.org/pdf/1708.07120.pdf

    This momentum scheduler is usually used together with the
    CyclicLrUpdaterHook to improve the performance in the 3D detection area.

    Attributes:
        target_ratio (tuple[float]): Relative ratio of the lowest momentum and
            the highest momentum to the initial momentum.
        cyclic_times (int): Number of cycles during training.
        step_ratio_up (float): The ratio of the increasing process of momentum
            in the total cycle.
        by_epoch (bool): Whether to update momentum by epoch.
    """

    def __init__(self,
                 by_epoch=False,
                 target_ratio=(0.85 / 0.95, 1),
                 cyclic_times=1,
                 step_ratio_up=0.4,
                 **kwargs):
        if isinstance(target_ratio, float):
            target_ratio = (target_ratio, target_ratio / 1e5)
        elif isinstance(target_ratio, tuple):
            target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
                if len(target_ratio) == 1 else target_ratio
        else:
            raise ValueError('target_ratio should be either float '
                             'or tuple, got {}'.format(type(target_ratio)))

        assert len(target_ratio) == 2, \
            '"target_ratio" must be list or tuple of two floats'
        assert 0 <= step_ratio_up < 1.0, \
            '"step_ratio_up" must be in range [0,1)'

        self.target_ratio = target_ratio
        self.cyclic_times = cyclic_times
        self.step_ratio_up = step_ratio_up
        self.momentum_phases = []  # init momentum_phases
        # currently only support by_epoch=False
        assert not by_epoch, \
            'currently only support "by_epoch" = False'
        super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)

    def before_run(self, runner):
        super(CyclicMomentumUpdaterHook, self).before_run(runner)
        # initiate momentum_phases
        # total momentum_phases are separated as up and down
        max_iter_per_phase = runner.max_iters // self.cyclic_times
        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
        self.momentum_phases.append(
            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
        self.momentum_phases.append([
            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
            self.target_ratio[0], self.target_ratio[1]
        ])

    def get_momentum(self, runner, base_momentum):
        curr_iter = runner.iter
        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
             end_ratio) in self.momentum_phases:
            curr_iter %= max_iter_per_phase
            if start_iter <= curr_iter < end_iter:
                progress = curr_iter - start_iter
                return annealing_cos(base_momentum * start_ratio,
                                     base_momentum * end_ratio,
                                     progress / (end_iter - start_iter))
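
# Illustrative walk-through (not part of the commit): the phase layout for a
# run with max_iters=10, cyclic_times=1 and step_ratio_up=0.4, matching the
# unit test further below.
max_iters, cyclic_times, step_ratio_up = 10, 1, 0.4
max_iter_per_phase = max_iters // cyclic_times            # 10
iter_up_phase = int(step_ratio_up * max_iter_per_phase)   # 4
# up phase:   iters 0-3, ratio anneals from 1 to target_ratio[0]
# down phase: iters 4-9, ratio anneals from target_ratio[0] to target_ratio[1]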

@@ -198,6 +198,21 @@ class Runner(object):
                'lr is not applicable because optimizer does not exist.')
        return [group['lr'] for group in self.optimizer.param_groups]

    def current_momentum(self):
        """Get current momentums.

        Returns:
            list: Current momentum of all param groups.
        """
        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        return [
            group['momentum']
            if 'momentum' in group.keys() else group['betas'][0]
            for group in self.optimizer.param_groups
        ]
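
# Illustration (not part of the commit): current_momentum() reads 'momentum'
# for SGD-style param groups and falls back to betas[0] for Adam-style groups.
import torch
import torch.nn as nn

model = nn.Linear(2, 1)
sgd = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
# a runner built with `sgd` would report current_momentum() == [0.95]
adam = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
# with `adam`, the first beta is reported: current_momentum() == [0.9]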

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.
@@ -254,7 +269,7 @@ class Runner(object):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self.call_hook('before_train_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
@@ -332,6 +347,12 @@ class Runner(object):
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
@@ -391,6 +412,19 @@
            hook = checkpoint_config
        self.register_hook(hook)

    def register_momentum_hooks(self, momentum_config):
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            hook_type = momentum_config.pop(
                'policy').title() + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook)
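
# Illustration (not part of the commit): a dict config is resolved by
# title-casing its policy name, e.g.
# 'cyclic'.title() + 'MomentumUpdaterHook' == 'CyclicMomentumUpdaterHook'.
momentum_config = dict(
    policy='cyclic', by_epoch=False, target_ratio=(0.85 / 0.95, 1),
    cyclic_times=1, step_ratio_up=0.4)
# register_momentum_hooks(momentum_config) pops 'policy', sets
# type='CyclicMomentumUpdaterHook' and builds the hook from the HOOKS
# registry; an already constructed hook instance may be passed instead.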

    def register_logger_hooks(self, log_config):
        log_interval = log_config['interval']
        for info in log_config['hooks']:
@@ -402,18 +436,21 @@
                               lr_config,
                               optimizer_config=None,
                               checkpoint_config=None,
                               log_config=None,
                               momentum_config=None):
        """Register default hooks for training.

        Default hooks include:

        - LrUpdaterHook
        - MomentumUpdaterHook
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
        - LoggerHook(s)
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hooks(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_hook(IterTimerHook())
...
"""
Tests the hooks with runners.
CommandLine:
pytest tests/test_hooks.py
xdoctest tests/test_hooks.py zero
"""
import os.path as osp import os.path as osp
import sys import sys
from unittest.mock import MagicMock from unittest.mock import MagicMock, call
import pytest import pytest
import torch import torch
...@@ -13,49 +21,129 @@ import mmcv.runner ...@@ -13,49 +21,129 @@ import mmcv.runner
def test_pavi_hook(): def test_pavi_hook():
sys.modules['pavi'] = MagicMock() sys.modules['pavi'] = MagicMock()
model = nn.Linear(1, 1) loader = DataLoader(torch.ones((5, 2)))
loader = DataLoader(torch.ones((5, 5))) runner = _build_demo_runner()
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
runner = mmcv.runner.Runner(
model=model,
work_dir=work_dir,
batch_processor=lambda model, x, **kwargs: {
'log_vars': {
'loss': 2.333
},
'num_samples': 5
})
hook = mmcv.runner.hooks.PaviLoggerHook( hook = mmcv.runner.hooks.PaviLoggerHook(
add_graph=False, add_last_ckpt=True) add_graph=False, add_last_ckpt=True)
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1) runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
assert hasattr(hook, 'writer') assert hasattr(hook, 'writer')
hook.writer.add_scalars.assert_called_with('val', {'loss': 2.333}, 5) hook.writer.add_scalars.assert_called_with('val', {
'learning_rate': 0.02,
'momentum': 0.95
}, 5)
hook.writer.add_snapshot_file.assert_called_with( hook.writer.add_snapshot_file.assert_called_with(
tag='data', tag='data',
snapshot_file_path=osp.join(work_dir, 'latest.pth'), snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'),
iteration=5) iteration=5)

def test_momentum_runner_hook():
    """
    xdoctest -m tests/test_hooks.py test_momentum_runner_hook
    """
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add momentum scheduler
    hook = mmcv.runner.hooks.momentum_updater.CyclicMomentumUpdaterHook(
        by_epoch=False,
        target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook(hook)

    # add LR scheduler
    hook = mmcv.runner.hooks.lr_updater.CyclicLrUpdaterHook(
        by_epoch=False,
        target_ratio=(10, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook(hook)
    runner.register_hook(mmcv.runner.hooks.IterTimerHook())

    # add pavi hook
    hook = mmcv.runner.hooks.PaviLoggerHook(
        interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)], 1)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.01999999999999999,
            'momentum': 0.95
        }, 0),
        call('train', {
            'learning_rate': 0.2,
            'momentum': 0.85
        }, 4),
        call('train', {
            'learning_rate': 0.155,
            'momentum': 0.875
        }, 6),
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


def test_cosine_runner_hook():
    """
    xdoctest -m tests/test_hooks.py test_cosine_runner_hook
    """
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add momentum scheduler
    hook = mmcv.runner.hooks.momentum_updater \
        .CosineAnealingMomentumUpdaterHook(
            min_momentum_ratio=0.99 / 0.95,
            by_epoch=False,
            warmup_iters=2,
            warmup_ratio=0.9 / 0.95)
    runner.register_hook(hook)

    # add LR scheduler
    hook = mmcv.runner.hooks.lr_updater.CosineAnealingLrUpdaterHook(
        by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
    runner.register_hook(hook)
    runner.register_hook(mmcv.runner.hooks.IterTimerHook())

    # add pavi hook
    hook = mmcv.runner.hooks.PaviLoggerHook(
        interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)], 1)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, 0),
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.97
        }, 5),
        call('train', {
            'learning_rate': 0.0004894348370484647,
            'momentum': 0.9890211303259032
        }, 9)
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)

@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
    sys.modules['mlflow'] = MagicMock()
    sys.modules['mlflow.pytorch'] = MagicMock()

    runner = _build_demo_runner()
    loader = DataLoader(torch.ones((5, 2)))

    hook = mmcv.runner.hooks.MlflowLoggerHook(
        exp_name='test', log_model=log_model)
@@ -63,7 +151,11 @@ def test_mlflow_hook(log_model):
    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)

    hook.mlflow.set_experiment.assert_called_with('test')
    hook.mlflow.log_metrics.assert_called_with(
        {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, step=5)
    if log_model:
        hook.mlflow_pytorch.log_model.assert_called_with(
            runner.model, 'models')
@@ -73,21 +165,36 @@ def test_mlflow_hook(log_model):

def test_wandb_hook():
    sys.modules['wandb'] = MagicMock()
    runner = _build_demo_runner()
    hook = mmcv.runner.hooks.WandbLoggerHook()
    loader = DataLoader(torch.ones((5, 2)))

    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)

    hook.wandb.init.assert_called_with()
    hook.wandb.log.assert_called_with({
        'learning_rate': 0.02,
        'momentum': 0.95
    },
                                      step=5)
    hook.wandb.join.assert_called_with()

def _build_demo_runner():
    model = nn.Linear(2, 1)
    work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')

    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

    log_config = dict(
        interval=1, hooks=[
            dict(type='TextLoggerHook'),
        ])

    runner = mmcv.runner.Runner(
        model=model,
        work_dir=work_dir,
        batch_processor=lambda model, x, **kwargs: {'loss': model(x) - 0},
        optimizer=optimizer)
    runner.register_logger_hooks(log_config)
    return runner