Unverified Commit ba059611 authored by Wenwei Zhang, committed by GitHub

Momentum scheduler (#167)

* track progress of iter&enum

* restore

* add momentum scheduler

* fix small bug

* cyclic scheduler

* fix bug

* fix second phase's bug

* reformat

* feature (cosine lr): use relative ratio for more flexible scheduler

* Fix (runner): fix bugs in runner

* Refactor (hook): refactor cosine/cyclic LR/momentum hook with unittest

* Clean unnecessary files and reformat

* Fix memory key error when GPU is not available

* Resolve comments

* Do not print momentum in text log

* Change hook register order

* Refactor max_iter

* Fix max_iter bugs in runner

* Enforce target_ratio to be either tuple or float
parent 8ac858b1
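Before the diff itself, a minimal sketch (adapted from the unit test added further below) of how the cyclic momentum and LR hooks introduced by this PR attach to a runner; the toy model, optimizer, batch_processor and work_dir here are placeholders rather than part of the commit.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import mmcv.runner

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
runner = mmcv.runner.Runner(
    model=model,
    work_dir='./demo_work_dir',  # placeholder path
    batch_processor=lambda model, x, **kwargs: {'loss': model(x).mean()},
    optimizer=optimizer)
# Momentum cycles down while LR cycles up (one-cycle style), as in the tests.
runner.register_hook(
    mmcv.runner.hooks.momentum_updater.CyclicMomentumUpdaterHook(
        by_epoch=False, target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1, step_ratio_up=0.4))
runner.register_hook(
    mmcv.runner.hooks.lr_updater.CyclicLrUpdaterHook(
        by_epoch=False, target_ratio=(10, 1),
        cyclic_times=1, step_ratio_up=0.4))
runner.run([DataLoader(torch.ones((10, 2)))], [('train', 1)], 1)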
exclude: ^tests/data/
repos:
  - repo: https://gitlab.com/pycqa/flake8.git
    rev: 3.7.9
    hooks:
      - id: flake8
  - repo: https://github.com/asottile/seed-isort-config
    rev: v2.1.0
    hooks:
      - id: seed-isort-config
  - repo: https://github.com/timothycrosley/isort
    rev: 4.3.21
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.29.0
    hooks:
      - id: yapf
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.5.0
    hooks:
      - id: trailing-whitespace
      - id: check-yaml
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: double-quote-string-fixer
      - id: fix-encoding-pragma
        args: ["--remove"]
      - id: mixed-line-ending
        args: ["--fix=lf"]
......@@ -7,6 +7,7 @@ from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook
from .optimizer import OptimizerHook
from .sampler_seed import DistSamplerSeedHook
......@@ -14,5 +15,5 @@ __all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook',
'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 'TextLoggerHook',
'TensorboardLoggerHook', 'WandbLoggerHook'
'TensorboardLoggerHook', 'WandbLoggerHook', 'MomentumUpdaterHook'
]
......@@ -70,6 +70,8 @@ class MlflowLoggerHook(LoggerHook):
tag = '{}/{}'.format(var, runner.mode)
if isinstance(val, numbers.Number):
metrics[tag] = val
metrics['learning_rate'] = runner.current_lr()[0]
metrics['momentum'] = runner.current_momentum()[0]
self.mlflow.log_metrics(metrics, step=runner.iter)
@master_only
......
......@@ -71,6 +71,8 @@ class PaviLoggerHook(LoggerHook):
for tag, val in runner.log_buffer.output.items():
if tag not in ['time', 'data_time'] and is_scalar(val):
tags[tag] = val
tags['learning_rate'] = runner.current_lr()[0]
tags['momentum'] = runner.current_momentum()[0]
if tags:
self.writer.add_scalars(runner.mode, tags, runner.iter)
......
......@@ -52,6 +52,10 @@ class TensorboardLoggerHook(LoggerHook):
else:
self.writer.add_scalar(tag, runner.log_buffer.output[var],
runner.iter)
self.writer.add_scalar('learning_rate',
runner.current_lr()[0], runner.iter)
self.writer.add_scalar('momentum',
runner.current_momentum()[0], runner.iter)
@master_only
def after_run(self, runner):
......
......@@ -49,7 +49,9 @@ class TextLoggerHook(LoggerHook):
log_str += 'eta: {}, '.format(eta_str)
log_str += ('time: {:.3f}, data_time: {:.3f}, '.format(
log_dict['time'], log_dict['data_time']))
log_str += 'memory: {}, '.format(log_dict['memory'])
# print memory usage only when CUDA is available
if torch.cuda.is_available():
log_str += 'memory: {}, '.format(log_dict['memory'])
else:
log_str = 'Epoch({}) [{}][{}]\t'.format(log_dict['mode'],
log_dict['epoch'] - 1,
......@@ -100,6 +102,7 @@ class TextLoggerHook(LoggerHook):
if mode == 'train':
log_dict['time'] = runner.log_buffer.output['time']
log_dict['data_time'] = runner.log_buffer.output['data_time']
# record max GPU memory only when CUDA is available
if torch.cuda.is_available():
log_dict['memory'] = self._get_max_memory(runner)
......
......@@ -45,6 +45,8 @@ class WandbLoggerHook(LoggerHook):
tag = '{}/{}'.format(var, runner.mode)
if isinstance(val, numbers.Number):
metrics[tag] = val
metrics['learning_rate'] = runner.current_lr()[0]
metrics['momentum'] = runner.current_momentum()[0]
if metrics:
self.wandb.log(metrics, step=runner.iter)
......
......@@ -199,11 +199,13 @@ class InvLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module
class CosineLrUpdaterHook(LrUpdaterHook):
class CosineAnealingLrUpdaterHook(LrUpdaterHook):
def __init__(self, target_lr=0, **kwargs):
self.target_lr = target_lr
super(CosineLrUpdaterHook, self).__init__(**kwargs)
def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
assert (min_lr is None) ^ (min_lr_ratio is None)
self.min_lr = min_lr
self.min_lr_ratio = min_lr_ratio
super(CosineAnealingLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
if self.by_epoch:
......@@ -212,5 +214,88 @@ class CosineLrUpdaterHook(LrUpdaterHook):
else:
progress = runner.iter
max_progress = runner.max_iters
return self.target_lr + 0.5 * (base_lr - self.target_lr) * \
(1 + cos(pi * (progress / max_progress)))
if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
target_lr = self.min_lr
return annealing_cos(base_lr, target_lr, progress / max_progress)
class CyclicLrUpdaterHook(LrUpdaterHook):
"""Cyclic LR Scheduler
Implemet the cyclical learning rate policy (CLR) described in
https://arxiv.org/pdf/1506.01186.pdf
Different from the original paper, we use cosine anealing rather than
triangular policy inside a cycle. This improves the performance in the
3D detection area.
Attributes:
target_ratio (tuple[float]): Relative ratio of the highest LR and the
lowest LR to the initial LR.
cyclic_times (int): Number of cycles during training
step_ratio_up (float): The ratio of the increasing process of LR in
the total cycle.
by_epoch (bool): Whether to update LR by epoch.
"""
def __init__(self,
by_epoch=False,
target_ratio=(10, 1e-4),
cyclic_times=1,
step_ratio_up=0.4,
**kwargs):
if isinstance(target_ratio, float):
target_ratio = (target_ratio, target_ratio / 1e5)
elif isinstance(target_ratio, tuple):
target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
if len(target_ratio) == 1 else target_ratio
else:
raise ValueError('target_ratio should be either float '
'or tuple, got {}'.format(type(target_ratio)))
assert len(target_ratio) == 2, \
'"target_ratio" must be list or tuple of two floats'
assert 0 <= step_ratio_up < 1.0, \
'"step_ratio_up" must be in range [0,1)'
self.target_ratio = target_ratio
self.cyclic_times = cyclic_times
self.step_ratio_up = step_ratio_up
self.lr_phases = [] # init lr_phases
assert not by_epoch, \
'currently only support "by_epoch" = False'
super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
def before_run(self, runner):
super(CyclicLrUpdaterHook, self).before_run(runner)
# initiate lr_phases
# total lr_phases are separated as up and down
max_iter_per_phase = runner.max_iters // self.cyclic_times
iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
self.lr_phases.append(
[0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
self.lr_phases.append([
iter_up_phase, max_iter_per_phase, max_iter_per_phase,
self.target_ratio[0], self.target_ratio[1]
])
def get_lr(self, runner, base_lr):
curr_iter = runner.iter
for (start_iter, end_iter, max_iter_per_phase, start_ratio,
end_ratio) in self.lr_phases:
curr_iter %= max_iter_per_phase
if start_iter <= curr_iter < end_iter:
progress = curr_iter - start_iter
return annealing_cos(base_lr * start_ratio,
base_lr * end_ratio,
progress / (end_iter - start_iter))
def annealing_cos(start, end, factor):
"""Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
cos_out = cos(pi * factor) + 1
return end + 0.5 * (start - end) * cos_out
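A quick sanity check of the annealing_cos helper above (a self-contained sketch with illustrative numbers, redefining the function locally so it runs on its own): at factor 0 it returns start, at 1 it returns end, and at 0.5 the midpoint.

from math import cos, pi

def annealing_cos(start, end, factor):
    """Cosine anneal from `start` to `end` as `factor` goes from 0.0 to 1.0."""
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * (start - end) * cos_out

assert abs(annealing_cos(0.02, 0.0, 0.0) - 0.02) < 1e-12  # start of the phase
assert abs(annealing_cos(0.02, 0.0, 0.5) - 0.01) < 1e-12  # halfway
assert abs(annealing_cos(0.02, 0.0, 1.0) - 0.00) < 1e-12  # end of the phase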
from .hook import HOOKS, Hook
from .lr_updater import annealing_cos
class MomentumUpdaterHook(Hook):
def __init__(self,
by_epoch=True,
warmup=None,
warmup_iters=0,
warmup_ratio=0.9,
**kwargs):
# validate the "warmup" argument
if warmup is not None:
if warmup not in ['constant', 'linear', 'exp']:
raise ValueError(
'"{}" is not a supported type for warming up, valid types'
' are "constant" and "linear"'.format(warmup))
if warmup is not None:
assert warmup_iters > 0, \
'"warmup_iters" must be a positive integer'
assert 0 < warmup_ratio <= 1.0, \
'"warmup_ratio" must be in range (0,1]'
self.by_epoch = by_epoch
self.warmup = warmup
self.warmup_iters = warmup_iters
self.warmup_ratio = warmup_ratio
self.base_momentum = [] # initial momentum for all param groups
self.regular_mom = []  # expected momentum if no warming up is performed
def _set_momentum(self, runner, momentum_groups):
for param_group, mom in zip(runner.optimizer.param_groups,
momentum_groups):
if 'momentum' in param_group.keys():
param_group['momentum'] = mom
elif 'betas' in param_group.keys():
param_group['betas'] = (mom, param_group['betas'][1])
def get_momentum(self, runner, base_momentum):
raise NotImplementedError
def get_regular_momentum(self, runner):
return [
self.get_momentum(runner, _base_momentum)
for _base_momentum in self.base_momentum
]
def get_warmup_momentum(self, cur_iters):
if self.warmup == 'constant':
warmup_momentum = [
_momentum / self.warmup_ratio
for _momentum in self.regular_mom
]
elif self.warmup == 'linear':
k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
warmup_momentum = [
_momentum / (1 - k) for _momentum in self.regular_mom
]
elif self.warmup == 'exp':
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
warmup_momentum = [_momentum / k for _momentum in self.regular_mom]
return warmup_momentum
def before_run(self, runner):
# NOTE: when resuming from a checkpoint,
# if 'initial_momentum' is not saved,
# it will be set according to the optimizer params
for group in runner.optimizer.param_groups:
if 'momentum' in group.keys():
group.setdefault('initial_momentum', group['momentum'])
else:
group.setdefault('initial_momentum', group['betas'][0])
self.base_momentum = [
group['initial_momentum']
for group in runner.optimizer.param_groups
]
def before_train_epoch(self, runner):
if not self.by_epoch:
return
self.regular_mom = self.get_regular_momentum(runner)
self._set_momentum(runner, self.regular_mom)
def before_train_iter(self, runner):
cur_iter = runner.iter
if not self.by_epoch:
self.regular_mom = self.get_regular_momentum(runner)
if self.warmup is None or cur_iter >= self.warmup_iters:
self._set_momentum(runner, self.regular_mom)
else:
warmup_momentum = self.get_warmup_momentum(cur_iter)
self._set_momentum(runner, warmup_momentum)
elif self.by_epoch:
if self.warmup is None or cur_iter > self.warmup_iters:
return
elif cur_iter == self.warmup_iters:
self._set_momentum(runner, self.regular_mom)
else:
warmup_momentum = self.get_warmup_momentum(cur_iter)
self._set_momentum(runner, warmup_momentum)
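A self-contained numeric sketch (illustrative values, not taken from the tests) of the 'linear' branch in get_warmup_momentum above: the warmup momentum starts at regular_mom / warmup_ratio and decays back to regular_mom over warmup_iters.

warmup_ratio, warmup_iters, regular_mom = 0.9, 10, 0.9

def linear_warmup_momentum(cur_iters):
    # mirrors the 'linear' branch of get_warmup_momentum with scalar values
    k = (1 - cur_iters / warmup_iters) * (1 - warmup_ratio)
    return regular_mom / (1 - k)

print(linear_warmup_momentum(0))   # 1.0   (= regular_mom / warmup_ratio)
print(linear_warmup_momentum(5))   # ~0.947
print(linear_warmup_momentum(10))  # 0.9   (the regular momentum is reached)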
@HOOKS.register_module
class CosineAnealingMomentumUpdaterHook(MomentumUpdaterHook):
def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
assert (min_momentum is None) ^ (min_momentum_ratio is None)
self.min_momentum = min_momentum
self.min_momentum_ratio = min_momentum_ratio
super(CosineAnealingMomentumUpdaterHook, self).__init__(**kwargs)
def get_momentum(self, runner, base_momentum):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
else:
progress = runner.iter
max_progress = runner.max_iters
if self.min_momentum_ratio is not None:
target_momentum = base_momentum * self.min_momentum_ratio
else:
target_momentum = self.min_momentum
return annealing_cos(base_momentum, target_momentum,
progress / max_progress)
@HOOKS.register_module
class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
"""Cyclic momentum Scheduler
Implemet the cyclical momentum scheduler policy described in
https://arxiv.org/pdf/1708.07120.pdf
This momentum scheduler usually used together with the CyclicLRUpdater
to improve the performance in the 3D detection area.
Attributes:
target_ratio (tuple[float]): Relative ratio of the lowest momentum and
the highest momentum to the initial momentum.
cyclic_times (int): Number of cycles during training
step_ratio_up (float): The ratio of the increasing process of momentum
in the total cycle.
by_epoch (bool): Whether to update momentum by epoch.
"""
def __init__(self,
by_epoch=False,
target_ratio=(0.85 / 0.95, 1),
cyclic_times=1,
step_ratio_up=0.4,
**kwargs):
if isinstance(target_ratio, float):
target_ratio = (target_ratio, target_ratio / 1e5)
elif isinstance(target_ratio, tuple):
target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
if len(target_ratio) == 1 else target_ratio
else:
raise ValueError('target_ratio should be either float '
'or tuple, got {}'.format(type(target_ratio)))
assert len(target_ratio) == 2, \
'"target_ratio" must be list or tuple of two floats'
assert 0 <= step_ratio_up < 1.0, \
'"step_ratio_up" must be in range [0,1)'
self.target_ratio = target_ratio
self.cyclic_times = cyclic_times
self.step_ratio_up = step_ratio_up
self.momentum_phases = [] # init momentum_phases
# currently only support by_epoch=False
assert not by_epoch, \
'currently only support "by_epoch" = False'
super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
def before_run(self, runner):
super(CyclicMomentumUpdaterHook, self).before_run(runner)
# initiate momentum_phases
# total momentum_phases are separated as up and down
max_iter_per_phase = runner.max_iters // self.cyclic_times
iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
self.momentum_phases.append(
[0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
self.momentum_phases.append([
iter_up_phase, max_iter_per_phase, max_iter_per_phase,
self.target_ratio[0], self.target_ratio[1]
])
def get_momentum(self, runner, base_momentum):
curr_iter = runner.iter
for (start_iter, end_iter, max_iter_per_phase, start_ratio,
end_ratio) in self.momentum_phases:
curr_iter %= max_iter_per_phase
if start_iter <= curr_iter < end_iter:
progress = curr_iter - start_iter
return annealing_cos(base_momentum * start_ratio,
base_momentum * end_ratio,
progress / (end_iter - start_iter))
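To make the two cyclic phases concrete, a self-contained sketch that reproduces the momentum values the new unit test checks (base momentum 0.95, target_ratio=(0.85 / 0.95, 1), 10 iterations, one cycle, step_ratio_up=0.4); annealing_cos is redefined locally so the snippet runs on its own.

from math import cos, pi

def annealing_cos(start, end, factor):
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * (start - end) * cos_out

base_momentum = 0.95
target_ratio = (0.85 / 0.95, 1)  # (lowest momentum, back to the initial one)
max_iters, cyclic_times, step_ratio_up = 10, 1, 0.4

max_iter_per_phase = max_iters // cyclic_times            # 10
iter_up_phase = int(step_ratio_up * max_iter_per_phase)   # 4
phases = [
    [0, iter_up_phase, max_iter_per_phase, 1, target_ratio[0]],
    [iter_up_phase, max_iter_per_phase, max_iter_per_phase,
     target_ratio[0], target_ratio[1]],
]

def momentum_at(curr_iter):
    for start_it, end_it, per_phase, start_r, end_r in phases:
        curr = curr_iter % per_phase
        if start_it <= curr < end_it:
            progress = curr - start_it
            return annealing_cos(base_momentum * start_r,
                                 base_momentum * end_r,
                                 progress / (end_it - start_it))

print(momentum_at(0))  # 0.95   (start of the "up" phase)
print(momentum_at(4))  # 0.85   (lowest momentum, start of the second phase)
print(momentum_at(6))  # ~0.875 (annealing back toward 0.95)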
......@@ -198,6 +198,21 @@ class Runner(object):
'lr is not applicable because optimizer does not exist.')
return [group['lr'] for group in self.optimizer.param_groups]
def current_momentum(self):
"""Get current momentums.
Returns:
list: Current momentum of all param groups.
"""
if self.optimizer is None:
raise RuntimeError(
'momentum is not applicable because optimizer does not exist.')
return [
group['momentum']
if 'momentum' in group.keys() else group['betas'][0]
for group in self.optimizer.param_groups
]
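A small illustration (assuming only a plain torch optimizer, nothing from the repo) of the betas fallback used by current_momentum above: Adam param groups carry no 'momentum' key, so beta1 is reported instead.

import torch
import torch.nn as nn

model = nn.Linear(2, 1)
adam = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
momentums = [
    group['momentum'] if 'momentum' in group else group['betas'][0]
    for group in adam.param_groups
]
print(momentums)  # [0.9]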
def register_hook(self, hook, priority='NORMAL'):
"""Register a hook into the hook list.
......@@ -254,7 +269,7 @@ class Runner(object):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch')
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
......@@ -332,6 +347,12 @@ class Runner(object):
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
......@@ -391,6 +412,19 @@ class Runner(object):
hook = checkpoint_config
self.register_hook(hook)
def register_momentum_hooks(self, momentum_config):
if momentum_config is None:
return
if isinstance(momentum_config, dict):
assert 'policy' in momentum_config
hook_type = momentum_config.pop(
'policy').title() + 'MomentumUpdaterHook'
momentum_config['type'] = hook_type
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
self.register_hook(hook)
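A hedged sketch of the config-dict form that register_momentum_hooks above accepts, and of how the 'policy' key is resolved into a registered hook class (mirroring the method body; the dict values are simply the cyclic settings used elsewhere in this PR).

import mmcv
from mmcv.runner.hooks import HOOKS

momentum_config = dict(
    policy='cyclic',  # 'cyclic'.title() + 'MomentumUpdaterHook'
    by_epoch=False,
    target_ratio=(0.85 / 0.95, 1),
    cyclic_times=1,
    step_ratio_up=0.4)
cfg = dict(momentum_config)
cfg['type'] = cfg.pop('policy').title() + 'MomentumUpdaterHook'
hook = mmcv.build_from_cfg(cfg, HOOKS)  # a CyclicMomentumUpdaterHook instance
# In practice the dict is passed straight to
# runner.register_training_hooks(..., momentum_config=momentum_config).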
def register_logger_hooks(self, log_config):
log_interval = log_config['interval']
for info in log_config['hooks']:
......@@ -402,18 +436,21 @@ class Runner(object):
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None):
log_config=None,
momentum_config=None):
"""Register default hooks for training.
Default hooks include:
- LrUpdaterHook
- MomentumUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook(s)
"""
self.register_lr_hook(lr_config)
self.register_momentum_hooks(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
......
"""
Tests the hooks with runners.
CommandLine:
pytest tests/test_hooks.py
xdoctest tests/test_hooks.py zero
"""
import os.path as osp
import sys
from unittest.mock import MagicMock
from unittest.mock import MagicMock, call
import pytest
import torch
......@@ -13,49 +21,129 @@ import mmcv.runner
def test_pavi_hook():
sys.modules['pavi'] = MagicMock()
model = nn.Linear(1, 1)
loader = DataLoader(torch.ones((5, 5)))
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
runner = mmcv.runner.Runner(
model=model,
work_dir=work_dir,
batch_processor=lambda model, x, **kwargs: {
'log_vars': {
'loss': 2.333
},
'num_samples': 5
})
loader = DataLoader(torch.ones((5, 2)))
runner = _build_demo_runner()
hook = mmcv.runner.hooks.PaviLoggerHook(
add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
assert hasattr(hook, 'writer')
hook.writer.add_scalars.assert_called_with('val', {'loss': 2.333}, 5)
hook.writer.add_scalars.assert_called_with('val', {
'learning_rate': 0.02,
'momentum': 0.95
}, 5)
hook.writer.add_snapshot_file.assert_called_with(
tag='data',
snapshot_file_path=osp.join(work_dir, 'latest.pth'),
snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'),
iteration=5)
def test_momentum_runner_hook():
"""
xdoctest -m tests/test_hooks.py test_momentum_runner_hook
"""
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner()
# add momentum scheduler
hook = mmcv.runner.hooks.momentum_updater.CyclicMomentumUpdaterHook(
by_epoch=False,
target_ratio=(0.85 / 0.95, 1),
cyclic_times=1,
step_ratio_up=0.4)
runner.register_hook(hook)
# add LR scheduler
hook = mmcv.runner.hooks.lr_updater.CyclicLrUpdaterHook(
by_epoch=False,
target_ratio=(10, 1),
cyclic_times=1,
step_ratio_up=0.4)
runner.register_hook(hook)
runner.register_hook(mmcv.runner.hooks.IterTimerHook())
# add pavi hook
hook = mmcv.runner.hooks.PaviLoggerHook(
interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1)
# TODO: use a more elegant way to check values
assert hasattr(hook, 'writer')
calls = [
call('train', {
'learning_rate': 0.01999999999999999,
'momentum': 0.95
}, 0),
call('train', {
'learning_rate': 0.2,
'momentum': 0.85
}, 4),
call('train', {
'learning_rate': 0.155,
'momentum': 0.875
}, 6),
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
def test_cosine_runner_hook():
"""
xdoctest -m tests/test_hooks.py test_cosine_runner_hook
"""
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner()
# add momentum scheduler
hook = mmcv.runner.hooks.momentum_updater \
.CosineAnealingMomentumUpdaterHook(
min_momentum_ratio=0.99 / 0.95,
by_epoch=False,
warmup_iters=2,
warmup_ratio=0.9 / 0.95)
runner.register_hook(hook)
# add LR scheduler
hook = mmcv.runner.hooks.lr_updater.CosineAnealingLrUpdaterHook(
by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
runner.register_hook(hook)
runner.register_hook(mmcv.runner.hooks.IterTimerHook())
# add pavi hook
hook = mmcv.runner.hooks.PaviLoggerHook(
interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1)
# TODO: use a more elegant way to check values
assert hasattr(hook, 'writer')
calls = [
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 0),
call('train', {
'learning_rate': 0.01,
'momentum': 0.97
}, 5),
call('train', {
'learning_rate': 0.0004894348370484647,
'momentum': 0.9890211303259032
}, 9)
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
sys.modules['mlflow'] = MagicMock()
sys.modules['mlflow.pytorch'] = MagicMock()
model = nn.Linear(1, 1)
loader = DataLoader(torch.ones((5, 5)))
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
runner = mmcv.runner.Runner(
model=model,
work_dir=work_dir,
batch_processor=lambda model, x, **kwargs: {
'log_vars': {
'accuracy': 0.98
},
'num_samples': 5
})
runner = _build_demo_runner()
loader = DataLoader(torch.ones((5, 2)))
hook = mmcv.runner.hooks.MlflowLoggerHook(
exp_name='test', log_model=log_model)
......@@ -63,7 +151,11 @@ def test_mlflow_hook(log_model):
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
hook.mlflow.set_experiment.assert_called_with('test')
hook.mlflow.log_metrics.assert_called_with({'accuracy/val': 0.98}, step=5)
hook.mlflow.log_metrics.assert_called_with(
{
'learning_rate': 0.02,
'momentum': 0.95
}, step=5)
if log_model:
hook.mlflow_pytorch.log_model.assert_called_with(
runner.model, 'models')
......@@ -73,21 +165,36 @@ def test_mlflow_hook(log_model):
def test_wandb_hook():
sys.modules['wandb'] = MagicMock()
runner = _build_demo_runner()
hook = mmcv.runner.hooks.WandbLoggerHook()
loader = DataLoader(torch.ones((5, 5)))
loader = DataLoader(torch.ones((5, 2)))
model = nn.Linear(1, 1)
runner = mmcv.runner.Runner(
model=model,
batch_processor=lambda model, x, **kwargs: {
'log_vars': {
'accuracy': 0.98
},
'num_samples': 5
})
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
hook.wandb.init.assert_called_with()
hook.wandb.log.assert_called_with({'accuracy/val': 0.98}, step=5)
hook.wandb.log.assert_called_with({
'learning_rate': 0.02,
'momentum': 0.95
},
step=5)
hook.wandb.join.assert_called_with()
def _build_demo_runner():
model = nn.Linear(2, 1)
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])
runner = mmcv.runner.Runner(
model=model,
work_dir=work_dir,
batch_processor=lambda model, x, **kwargs: {'loss': model(x) - 0},
optimizer=optimizer)
runner.register_logger_hooks(log_config)
return runner