Commit f0c43fdc authored by zhouzaida, committed by GitHub

[Feature] Add OneCycleLrUpdaterHook (#906)

* [Feature] Add OneCycleLrUpdaterHook

* fix docstring

* fix docstring

* Remove redundant code
parent 3ae1b257
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from math import cos, pi

from .hook import HOOKS, Hook
@@ -398,6 +399,124 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
progress / (end_iter - start_iter))
@HOOKS.register_module()
class OneCycleLrUpdaterHook(LrUpdaterHook):
"""One Cycle LR Scheduler.
The 1cycle learning rate policy changes the learning rate after every
batch; the policy is described in
https://arxiv.org/pdf/1708.07120.pdf
Args:
max_lr (float or list): Upper learning rate boundaries in the cycle
for each parameter group.
pct_start (float): The percentage of the cycle (in number of steps)
spent increasing the learning rate.
Default: 0.3
anneal_strategy (str): {'cos', 'linear'}
Specifies the annealing strategy: 'cos' for cosine annealing,
'linear' for linear annealing.
Default: 'cos'
div_factor (float): Determines the initial learning rate via
initial_lr = max_lr/div_factor
Default: 25
final_div_factor (float): Determines the minimum learning rate via
min_lr = initial_lr/final_div_factor
Default: 1e4
three_phase (bool): If three_phase is True, use a third phase of the
schedule to annihilate the learning rate according to
final_div_factor instead of modifying the second phase (the first
two phases will be symmetrical about the step indicated by
pct_start).
Default: False
"""
def __init__(self,
max_lr,
pct_start=0.3,
anneal_strategy='cos',
div_factor=25,
final_div_factor=1e4,
three_phase=False,
**kwargs):
# validate by_epoch, currently only support by_epoch = False
if 'by_epoch' not in kwargs:
kwargs['by_epoch'] = False
else:
assert not kwargs['by_epoch'], \
'currently only support "by_epoch" = False'
if not isinstance(max_lr, (numbers.Number, list, dict)):
raise ValueError('the type of max_lr must be numbers.Number, list '
f'or dict, but got {type(max_lr)}')
self._max_lr = max_lr
# validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
raise ValueError('expected float between 0 and 1 for pct_start, but '
f'got {pct_start}')
self.pct_start = pct_start
# validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
raise ValueError('anneal_strategy must be one of "cos" or '
f'"linear", instead got {anneal_strategy}')
elif anneal_strategy == 'cos':
self.anneal_func = annealing_cos
elif anneal_strategy == 'linear':
self.anneal_func = annealing_linear
self.div_factor = div_factor
self.final_div_factor = final_div_factor
self.three_phase = three_phase
self.lr_phases = [] # init lr_phases
super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
def before_run(self, runner):
if isinstance(runner.optimizer, dict):
self.base_lr = {}
for k, optim in runner.optimizer.items():
_max_lr = format_param(k, optim, self._max_lr)
self.base_lr[k] = [lr / self.div_factor for lr in _max_lr]
for group, lr in zip(optim.param_groups, self.base_lr[k]):
group.setdefault('initial_lr', lr)
else:
k = type(runner.optimizer).__name__
_max_lr = format_param(k, runner.optimizer, self._max_lr)
self.base_lr = [lr / self.div_factor for lr in _max_lr]
for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
group.setdefault('initial_lr', lr)
if self.three_phase:
self.lr_phases.append([
float(self.pct_start * runner.max_iters) - 1, 1,
self.div_factor
])
self.lr_phases.append([
float(2 * self.pct_start * runner.max_iters) - 2,
self.div_factor, 1
])
self.lr_phases.append(
[runner.max_iters - 1, 1, 1 / self.final_div_factor])
else:
self.lr_phases.append([
float(self.pct_start * runner.max_iters) - 1, 1,
self.div_factor
])
self.lr_phases.append([
runner.max_iters - 1, self.div_factor,
1 / self.final_div_factor
])
def get_lr(self, runner, base_lr):
curr_iter = runner.iter
start_iter = 0
for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
if curr_iter <= end_iter:
pct = (curr_iter - start_iter) / (end_iter - start_iter)
lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
pct)
break
start_iter = end_iter
return lr
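For reference, a minimal sketch (illustrative values only) of attaching this hook to a runner through a config dict, mirroring the test added later in this commit:
# Sketch only: values are illustrative, not recommended defaults.
lr_hook_cfg = dict(
    type='OneCycleLrUpdaterHook',
    max_lr=0.01,            # peak LR, reached after pct_start of the schedule
    pct_start=0.3,          # fraction of iterations spent increasing the LR
    anneal_strategy='cos',  # cosine annealing between phase boundaries
    div_factor=25,          # initial_lr = max_lr / div_factor
    final_div_factor=1e4,   # min_lr = initial_lr / final_div_factor
    three_phase=False)
# runner.register_hook_from_cfg(lr_hook_cfg)  # assumes an already-built runner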
def annealing_cos(start, end, factor, weight=1):
"""Calculate annealing cos learning rate.
@@ -414,3 +533,31 @@ def annealing_cos(start, end, factor, weight=1):
"""
cos_out = cos(pi * factor) + 1
return end + 0.5 * weight * (start - end) * cos_out
def annealing_linear(start, end, factor):
"""Calculate annealing linear learning rate.
Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
Args:
start (float): The starting learning rate of the linear annealing.
end (float): The ending learning rate of the linear annealing.
factor (float): The current annealing percentage. Range from 0.0 to
1.0.
"""
return start + (end - start) * factor
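For example (illustrative numbers), annealing_linear simply interpolates between the endpoints:
# annealing_linear(0.01, 0.001, 0.5) == 0.01 + (0.001 - 0.01) * 0.5 == 0.0055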
def format_param(name, optim, param):
if isinstance(param, numbers.Number):
return [param] * len(optim.param_groups)
elif isinstance(param, (list, tuple)): # multi param groups
if len(param) != len(optim.param_groups):
raise ValueError(f'expected {len(optim.param_groups)} '
f'values for {name}, got {len(param)}')
return param
else: # multi optimizers
if name not in param:
raise KeyError(f'{name} is not found in {param.keys()}')
return param[name]
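To make the phase arithmetic concrete, a small standalone sketch with assumed values (max_lr=0.01, max_iters=100, two-phase mode) that mirrors the lr_phases construction in before_run and the lookup in get_lr:
# Sketch with assumed values; reuses annealing_cos defined above.
max_iters, pct_start, div_factor, final_div_factor = 100, 0.3, 25, 1e4
base_lr = 0.01 / div_factor  # initial_lr = max_lr / div_factor = 4e-4
# Each phase is [end_iter, start_factor, end_factor], scaling base_lr.
lr_phases = [
    [float(pct_start * max_iters) - 1, 1, div_factor],   # warm-up: 4e-4 -> 1e-2
    [max_iters - 1, div_factor, 1 / final_div_factor],   # decay:   1e-2 -> 4e-8
]
for it in (0, 29, 99):
    start_iter = 0
    for end_iter, start_factor, end_factor in lr_phases:
        if it <= end_iter:
            pct = (it - start_iter) / (end_iter - start_iter)
            print(it, annealing_cos(base_lr * start_factor,
                                    base_lr * end_factor, pct))
            break
        start_iter = end_iter
# prints roughly 4e-4 at iter 0, 1e-2 at iter 29 and 4e-8 at iter 99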
from .hook import HOOKS, Hook
from .lr_updater import annealing_cos, annealing_linear, format_param
class MomentumUpdaterHook(Hook):
@@ -130,7 +130,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
"""Cyclic momentum Scheduler.
Implement the cyclical momentum scheduler policy described in
https://arxiv.org/pdf/1708.07120.pdf
This momentum scheduler usually used together with the CyclicLRUpdater
@@ -197,3 +197,198 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
return annealing_cos(base_momentum * start_ratio,
base_momentum * end_ratio,
progress / (end_iter - start_iter))
@HOOKS.register_module()
class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
"""OneCycle momentum Scheduler.
This momentum scheduler is usually used together with the OneCycleLrUpdater
to improve performance.
Args:
base_momentum (float or list): Lower momentum boundaries in the cycle
for each parameter group. Note that momentum is cycled inversely
to learning rate; at the peak of a cycle, momentum is
'base_momentum' and learning rate is 'max_lr'.
Default: 0.85
max_momentum (float or list): Upper momentum boundaries in the cycle
for each parameter group. Functionally,
it defines the cycle amplitude (max_momentum - base_momentum).
Note that momentum is cycled inversely
to learning rate; at the start of a cycle, momentum is
'max_momentum' and learning rate is 'base_lr'
Default: 0.95
pct_start (float): The percentage of the cycle (in number of steps)
spent increasing the learning rate.
Default: 0.3
anneal_strategy (str): {'cos', 'linear'}
Specifies the annealing strategy: 'cos' for cosine annealing,
'linear' for linear annealing.
Default: 'cos'
three_phase (bool): If three_phase is True, use a third phase of the
schedule in which the momentum is held at max_momentum while the
learning rate is annihilated, instead of modifying the second phase
(the first two phases will be symmetrical about the step indicated
by pct_start).
Default: False
"""
def __init__(self,
base_momentum=0.85,
max_momentum=0.95,
pct_start=0.3,
anneal_strategy='cos',
three_phase=False,
**kwargs):
# validate by_epoch, currently only support by_epoch=False
if 'by_epoch' not in kwargs:
kwargs['by_epoch'] = False
else:
assert not kwargs['by_epoch'], \
'currently only support "by_epoch" = False'
if not isinstance(base_momentum, (float, list, dict)):
raise ValueError('base_momentum must be one of float, list '
f'or dict, but got {type(base_momentum)}')
self._base_momentum = base_momentum
if not isinstance(max_momentum, (float, list, dict)):
raise ValueError('max_momentum must be one of float, list '
f'or dict, but got {type(max_momentum)}')
self._max_momentum = max_momentum
# validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
raise ValueError('expected float between 0 and 1 for pct_start, but '
f'got {pct_start}')
self.pct_start = pct_start
# validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
raise ValueError('anneal_strategy must be one of "cos" or '
f'"linear", instead got {anneal_strategy}')
elif anneal_strategy == 'cos':
self.anneal_func = annealing_cos
elif anneal_strategy == 'linear':
self.anneal_func = annealing_linear
self.three_phase = three_phase
self.momentum_phases = [] # init momentum_phases
super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
def before_run(self, runner):
if isinstance(runner.optimizer, dict):
for k, optim in runner.optimizer.items():
if ('momentum' not in optim.defaults
and 'betas' not in optim.defaults):
raise ValueError('optimizer must support momentum with '
'option enabled')
self.use_beta1 = 'betas' in optim.defaults
_base_momentum = format_param(k, optim, self._base_momentum)
_max_momentum = format_param(k, optim, self._max_momentum)
for group, b_momentum, m_momentum in zip(
optim.param_groups, _base_momentum, _max_momentum):
if self.use_beta1:
_, beta2 = group['betas']
group['betas'] = (m_momentum, beta2)
else:
group['momentum'] = m_momentum
group['base_momentum'] = b_momentum
group['max_momentum'] = m_momentum
else:
optim = runner.optimizer
if ('momentum' not in optim.defaults
and 'betas' not in optim.defaults):
raise ValueError('optimizer must support momentum with '
'option enabled')
self.use_beta1 = 'betas' in optim.defaults
k = type(optim).__name__
_base_momentum = format_param(k, optim, self._base_momentum)
_max_momentum = format_param(k, optim, self._max_momentum)
for group, b_momentum, m_momentum in zip(optim.param_groups,
_base_momentum,
_max_momentum):
if self.use_beta1:
_, beta2 = group['betas']
group['betas'] = (m_momentum, beta2)
else:
group['momentum'] = m_momentum
group['base_momentum'] = b_momentum
group['max_momentum'] = m_momentum
if self.three_phase:
self.momentum_phases.append({
'end_iter':
float(self.pct_start * runner.max_iters) - 1,
'start_momentum':
'max_momentum',
'end_momentum':
'base_momentum'
})
self.momentum_phases.append({
'end_iter':
float(2 * self.pct_start * runner.max_iters) - 2,
'start_momentum':
'base_momentum',
'end_momentum':
'max_momentum'
})
self.momentum_phases.append({
'end_iter': runner.max_iters - 1,
'start_momentum': 'max_momentum',
'end_momentum': 'max_momentum'
})
else:
self.momentum_phases.append({
'end_iter':
float(self.pct_start * runner.max_iters) - 1,
'start_momentum':
'max_momentum',
'end_momentum':
'base_momentum'
})
self.momentum_phases.append({
'end_iter': runner.max_iters - 1,
'start_momentum': 'base_momentum',
'end_momentum': 'max_momentum'
})
def _set_momentum(self, runner, momentum_groups):
if isinstance(runner.optimizer, dict):
for k, optim in runner.optimizer.items():
for param_group, mom in zip(optim.param_groups,
momentum_groups[k]):
if 'momentum' in param_group.keys():
param_group['momentum'] = mom
elif 'betas' in param_group.keys():
param_group['betas'] = (mom, param_group['betas'][1])
else:
for param_group, mom in zip(runner.optimizer.param_groups,
momentum_groups):
if 'momentum' in param_group.keys():
param_group['momentum'] = mom
elif 'betas' in param_group.keys():
param_group['betas'] = (mom, param_group['betas'][1])
def get_momentum(self, runner, param_group):
curr_iter = runner.iter
start_iter = 0
for i, phase in enumerate(self.momentum_phases):
end_iter = phase['end_iter']
if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
pct = (curr_iter - start_iter) / (end_iter - start_iter)
momentum = self.anneal_func(param_group[phase['start_momentum']],
param_group[phase['end_momentum']], pct)
break
start_iter = end_iter
return momentum
def get_regular_momentum(self, runner):
if isinstance(runner.optimizer, dict):
momentum_groups = {}
for k, optim in runner.optimizer.items():
momentum_groups[k] = []
for param_group in optim.param_groups:
momentum_groups[k].append(
self.get_momentum(runner, param_group))
return momentum_groups
else:
momentum_groups = []
for param_group in runner.optimizer.param_groups:
momentum_groups.append(self.get_momentum(runner, param_group))
return momentum_groups
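A short illustration (assumed values) of the inverse cycling described above: momentum anneals from max_momentum down to base_momentum while the LR rises, then back up while the LR decays.
# Minimal sketch; annealing_cos comes from the lr_updater module shown earlier in this diff.
from mmcv.runner.hooks.lr_updater import annealing_cos

group = {'base_momentum': 0.85, 'max_momentum': 0.95}  # assumed bounds
mid = annealing_cos(group['max_momentum'], group['base_momentum'], 0.5)
print(mid)  # 0.9, i.e. halfway between the bounds at the middle of the warm-up phase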
"""Tests the hooks with runners. """Tests the hooks with runners.
CommandLine: CommandLine:
pytest tests/test_hooks.py pytest tests/test_runner/test_hooks.py
xdoctest tests/test_hooks.py zero xdoctest tests/test_hooks.py zero
""" """
import logging import logging
...@@ -21,7 +21,8 @@ from torch.utils.data import DataLoader ...@@ -21,7 +21,8 @@ from torch.utils.data import DataLoader
from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook, from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook, MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
build_runner) build_runner)
from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
OneCycleLrUpdaterHook)
def test_checkpoint_hook(): def test_checkpoint_hook():
...@@ -251,6 +252,71 @@ def test_cosine_runner_hook(): ...@@ -251,6 +252,71 @@ def test_cosine_runner_hook():
hook.writer.add_scalars.assert_has_calls(calls, any_order=True) hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
def test_one_cycle_runner_hook():
"""Test OneCycleLrUpdaterHook and OneCycleMomentumUpdaterHook."""
with pytest.raises(AssertionError):
# by_epoch should be False
OneCycleLrUpdaterHook(max_lr=0.1, by_epoch=True)
with pytest.raises(ValueError):
# expected float between 0 and 1
OneCycleLrUpdaterHook(max_lr=0.1, pct_start=-0.1)
with pytest.raises(ValueError):
# anneal_strategy should be either 'cos' or 'linear'
OneCycleLrUpdaterHook(max_lr=0.1, anneal_strategy='sin')
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner()
# add momentum scheduler
hook_cfg = dict(
type='OneCycleMomentumUpdaterHook',
base_momentum=0.85,
max_momentum=0.95,
pct_start=0.5,
anneal_strategy='cos',
three_phase=False)
runner.register_hook_from_cfg(hook_cfg)
# add LR scheduler
hook_cfg = dict(
type='OneCycleLrUpdaterHook',
max_lr=0.01,
pct_start=0.5,
anneal_strategy='cos',
div_factor=25,
final_div_factor=1e4,
three_phase=False)
runner.register_hook_from_cfg(hook_cfg)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
runner.register_hook(IterTimerHook())
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
# TODO: use a more elegant way to check values
assert hasattr(hook, 'writer')
calls = [
call('train', {
'learning_rate': 0.0003999999999999993,
'momentum': 0.95
}, 1),
call('train', {
'learning_rate': 0.00904508879153485,
'momentum': 0.8595491502812526
}, 6),
call('train', {
'learning_rate': 4e-08,
'momentum': 0.95
}, 10)
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
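For reference, the asserted numbers follow directly from the hook parameters above:
# initial_lr = max_lr / div_factor           = 0.01 / 25  = 4e-4 -> LR logged at step 1
# min_lr     = initial_lr / final_div_factor = 4e-4 / 1e4 = 4e-8 -> LR logged at step 10
# Momentum starts and ends at max_momentum = 0.95; the step-6 values (~9.045e-3 and
# ~0.85955) are cosine-annealed 20% of the way into the decay phase.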
def test_cosine_restart_lr_update_hook():
"""Test CosineRestartLrUpdaterHook."""
with pytest.raises(AssertionError):
...