Unverified commit 580e374e, authored by Mashiro, committed by GitHub

Add decay to CyclicLrUpdater (#1655)



* Add decay to cyclic LR

* Simplify tests

* Fix error in tests

* fix cyclic lr

* add weight decay CyclicLrUpdate test

* Update mmcv/runner/hooks/lr_updater.py

update docstring
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* update CyclicLrUpdater unit test

* add comments to CyclicLrUpdaterHook

* fix CyclicLrUpdater. Support peak_lr decay when target[0] < 1

* add momentum decay to CyclicMomentumUpdater

* update momentum unit test

* Fix CyclicMomentum comment

Co-authored-by: Dmytro Panchenko <panchenkodmytry@gmail.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 0448fcf9
mmcv/runner/hooks/lr_updater.py
@@ -420,15 +420,19 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
     3D detection area.

     Args:
-        by_epoch (bool): Whether to update LR by epoch.
-        target_ratio (tuple[float]): Relative ratio of the highest LR and the
-            lowest LR to the initial LR.
-        cyclic_times (int): Number of cycles during training
-        step_ratio_up (float): The ratio of the increasing process of LR in
-            the total cycle.
-        anneal_strategy (str): {'cos', 'linear'}
+        by_epoch (bool, optional): Whether to update LR by epoch.
+        target_ratio (tuple[float], optional): Relative ratio of the highest
+            LR and the lowest LR to the initial LR.
+        cyclic_times (int, optional): Number of cycles during training.
+        step_ratio_up (float, optional): The ratio of the increasing process
+            of LR in the total cycle.
+        anneal_strategy (str, optional): {'cos', 'linear'}
             Specifies the annealing strategy: 'cos' for cosine annealing,
             'linear' for linear annealing. Default: 'cos'.
+        gamma (float, optional): Cycle decay ratio. Default: 1.
+            It takes values in the range (0, 1]. When it is smaller than 1,
+            the gap between the maximum and the minimum LR shrinks after
+            every cycle. `New in version 1.4.4.`
     """

     def __init__(self,
@@ -437,6 +441,7 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
                  cyclic_times=1,
                  step_ratio_up=0.4,
                  anneal_strategy='cos',
+                 gamma=1,
                  **kwargs):
         if isinstance(target_ratio, float):
             target_ratio = (target_ratio, target_ratio / 1e5)
@@ -451,10 +456,14 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
             '"target_ratio" must be list or tuple of two floats'
         assert 0 <= step_ratio_up < 1.0, \
             '"step_ratio_up" must be in range [0,1)'
+        assert 0 < gamma <= 1, \
+            '"gamma" must be in range (0, 1]'
         self.target_ratio = target_ratio
         self.cyclic_times = cyclic_times
         self.step_ratio_up = step_ratio_up
+        self.gamma = gamma
+        self.max_iter_per_phase = None
         self.lr_phases = []  # init lr_phases
         # validate anneal_strategy
         if anneal_strategy not in ['cos', 'linear']:
@@ -473,21 +482,32 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
         super(CyclicLrUpdaterHook, self).before_run(runner)
         # initiate lr_phases
         # total lr_phases are separated as up and down
-        max_iter_per_phase = runner.max_iters // self.cyclic_times
-        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
-        self.lr_phases.append(
-            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+        self.max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * self.max_iter_per_phase)
+        self.lr_phases.append([0, iter_up_phase, 1, self.target_ratio[0]])
         self.lr_phases.append([
-            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
-            self.target_ratio[0], self.target_ratio[1]
+            iter_up_phase, self.max_iter_per_phase, self.target_ratio[0],
+            self.target_ratio[1]
         ])

     def get_lr(self, runner, base_lr):
-        curr_iter = runner.iter
-        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
-             end_ratio) in self.lr_phases:
-            curr_iter %= max_iter_per_phase
+        curr_iter = runner.iter % self.max_iter_per_phase
+        curr_cycle = runner.iter // self.max_iter_per_phase
+        # decay factor applied to the LR range in the current cycle
+        scale = self.gamma**curr_cycle
+        for (start_iter, end_iter, start_ratio, end_ratio) in self.lr_phases:
             if start_iter <= curr_iter < end_iter:
+                # Apply cycle scaling to gradually reduce the difference
+                # between max_lr and base lr. The target end_ratio can be
+                # expressed as:
+                # end_ratio = (base_lr + scale * (max_lr - base_lr)) / base_lr
+                # iteration: 0 to iter_up_phase
+                if start_iter == 0:
+                    end_ratio = 1 - scale + end_ratio * scale
+                # iteration: iter_up_phase to max_iter_per_phase
+                else:
+                    start_ratio = 1 - scale + start_ratio * scale
                 progress = curr_iter - start_iter
                 return self.anneal_func(base_lr * start_ratio,
                                         base_lr * end_ratio,
                                         progress / (end_iter - start_iter))
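To see what `gamma` does to the schedule, consider the peak LR reached in each cycle. With `scale = gamma**cycle`, the effective up-phase target ratio is `1 - scale + target_ratio[0] * scale`, which moves from `target_ratio[0]` toward 1 as cycles pass. A minimal standalone sketch (all values purely illustrative, not mmcv defaults):

# Peak LR reached in cycle k under the decay rule above:
#   peak_lr(k) = base_lr * (1 - gamma**k + target_ratio_max * gamma**k)
base_lr, target_ratio_max, gamma = 0.1, 10.0, 0.5  # illustrative values

for k in range(4):
    scale = gamma**k
    peak_lr = base_lr * (1 - scale + target_ratio_max * scale)
    print(k, peak_lr)  # 1.0, 0.55, 0.325, 0.2125 -> decays toward base_lr

With gamma=1 every cycle peaks at base_lr * target_ratio[0]; with gamma < 1 the peaks decay geometrically toward base_lr while the valley stays anchored, which is why the docstring describes the maximum-to-minimum gap shrinking per cycle.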
mmcv/runner/hooks/momentum_updater.py
@@ -239,6 +239,13 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
         step_ratio_up (float): The ratio of the increasing process of momentum
             in the total cycle.
         by_epoch (bool): Whether to update momentum by epoch.
+        anneal_strategy (str, optional): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing. Default: 'cos'.
+        gamma (float, optional): Cycle decay ratio. Default: 1.
+            It takes values in the range (0, 1]. When it is smaller than 1,
+            the gap between the maximum and the minimum momentum shrinks
+            after every cycle. `New in version 1.4.4.`
     """

     def __init__(self,
@@ -246,6 +253,8 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
                  target_ratio=(0.85 / 0.95, 1),
                  cyclic_times=1,
                  step_ratio_up=0.4,
+                 anneal_strategy='cos',
+                 gamma=1,
                  **kwargs):
         if isinstance(target_ratio, float):
             target_ratio = (target_ratio, target_ratio / 1e5)
@@ -264,7 +273,16 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
         self.target_ratio = target_ratio
         self.cyclic_times = cyclic_times
         self.step_ratio_up = step_ratio_up
+        self.gamma = gamma
         self.momentum_phases = []  # init momentum_phases
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
         # currently only support by_epoch=False
         assert not by_epoch, \
             'currently only support "by_epoch" = False'
@@ -276,21 +294,34 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
         # total momentum_phases are separated as up and down
         max_iter_per_phase = runner.max_iters // self.cyclic_times
         iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        self.max_iter_per_phase = max_iter_per_phase
         self.momentum_phases.append(
-            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+            [0, iter_up_phase, 1, self.target_ratio[0]])
         self.momentum_phases.append([
-            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
-            self.target_ratio[0], self.target_ratio[1]
+            iter_up_phase, max_iter_per_phase, self.target_ratio[0],
+            self.target_ratio[1]
         ])

     def get_momentum(self, runner, base_momentum):
-        curr_iter = runner.iter
-        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
-             end_ratio) in self.momentum_phases:
-            curr_iter %= max_iter_per_phase
+        curr_iter = runner.iter % self.max_iter_per_phase
+        curr_cycle = runner.iter // self.max_iter_per_phase
+        # decay factor applied to the momentum range in the current cycle
+        scale = self.gamma**curr_cycle
+        for (start_iter, end_iter, start_ratio, end_ratio) \
+                in self.momentum_phases:
             if start_iter <= curr_iter < end_iter:
+                # Apply cycle scaling to gradually reduce the difference
+                # between max_momentum and base momentum. The target end_ratio
+                # can be expressed as:
+                # end_ratio = (base_momentum + scale *
+                #     (max_momentum - base_momentum)) / base_momentum
+                # iteration: 0 to iter_up_phase
+                if start_iter == 0:
+                    end_ratio = 1 - scale + end_ratio * scale
+                # iteration: iter_up_phase to max_iter_per_phase
+                else:
+                    start_ratio = 1 - scale + start_ratio * scale
                 progress = curr_iter - start_iter
-                return annealing_cos(base_momentum * start_ratio,
-                                     base_momentum * end_ratio,
-                                     progress / (end_iter - start_iter))
+                return self.anneal_func(base_momentum * start_ratio,
+                                        base_momentum * end_ratio,
+                                        progress / (end_iter - start_iter))
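For reference, both hooks can be attached to a runner through config dicts, mirroring how the tests below wire them up. A minimal sketch, assuming `runner` is an mmcv runner built as in these tests and that both hook classes are available through `register_hook_from_cfg` (values are illustrative, not defaults):

# Sketch only: assumes a built mmcv runner; values are illustrative.
lr_hook_cfg = dict(
    type='CyclicLrUpdaterHook',
    by_epoch=False,
    target_ratio=(10.0, 1.0),  # (peak, valley) relative to the initial LR
    cyclic_times=2,
    step_ratio_up=0.5,
    anneal_strategy='linear',
    gamma=0.5)                 # halve the LR range every cycle
momentum_hook_cfg = dict(
    type='CyclicMomentumUpdaterHook',
    by_epoch=False,
    target_ratio=(0.85 / 0.95, 1),  # momentum dips while the LR peaks
    cyclic_times=2,
    step_ratio_up=0.4,
    gamma=0.5)
runner.register_hook_from_cfg(lr_hook_cfg)
runner.register_hook_from_cfg(momentum_hook_cfg)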
tests/test_hooks.py
@@ -354,20 +354,23 @@ def test_sync_buffers_hook():
     shutil.rmtree(runner.work_dir)


-@pytest.mark.parametrize('multi_optimziers', (True, False))
-def test_momentum_runner_hook(multi_optimziers):
+@pytest.mark.parametrize('multi_optimizers, max_iters, gamma, cyclic_times',
+                         [(True, 8, 1, 1), (False, 8, 0.5, 2)])
+def test_momentum_runner_hook(multi_optimizers, max_iters, gamma,
+                              cyclic_times):
     """xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
     sys.modules['pavi'] = MagicMock()
     loader = DataLoader(torch.ones((10, 2)))
-    runner = _build_demo_runner(multi_optimziers=multi_optimziers)
+    runner = _build_demo_runner(multi_optimziers=multi_optimizers)

     # add momentum scheduler
     hook_cfg = dict(
         type='CyclicMomentumUpdaterHook',
         by_epoch=False,
         target_ratio=(0.85 / 0.95, 1),
-        cyclic_times=1,
-        step_ratio_up=0.4)
+        cyclic_times=cyclic_times,
+        step_ratio_up=0.4,
+        gamma=gamma)
     runner.register_hook_from_cfg(hook_cfg)

     # add momentum LR scheduler
@@ -388,7 +391,7 @@ def test_momentum_runner_hook(multi_optimziers):
     # TODO: use a more elegant way to check values
     assert hasattr(hook, 'writer')
-    if multi_optimziers:
+    if multi_optimizers:
         calls = [
             call(
                 'train', {
@@ -419,13 +422,17 @@ def test_momentum_runner_hook(multi_optimziers):
                 'momentum': 0.95
             }, 1),
             call('train', {
-                'learning_rate': 0.2,
-                'momentum': 0.85
-            }, 5),
+                'learning_rate': 0.11,
+                'momentum': 0.85
+            }, 3),
             call('train', {
-                'learning_rate': 0.155,
-                'momentum': 0.875
-            }, 7),
+                'learning_rate': 0.1879422863405995,
+                'momentum': 0.95
+            }, 6),
+            call('train', {
+                'learning_rate': 0.11000000000000001,
+                'momentum': 0.9
+            }, 8),
         ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)

@@ -1267,26 +1274,31 @@ def test_step_runner_hook(multi_optimziers):
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


-@pytest.mark.parametrize('multi_optimizers, max_iters', [(True, 8),
-                                                         (False, 8)])
-def test_cyclic_lr_update_hook(multi_optimizers, max_iters):
+@pytest.mark.parametrize('multi_optimizers, max_iters, gamma, cyclic_times',
+                         [(True, 8, 1, 1), (False, 8, 0.5, 2)])
+def test_cyclic_lr_update_hook(multi_optimizers, max_iters, gamma,
+                               cyclic_times):
     """Test CyclicLrUpdateHook."""
     with pytest.raises(AssertionError):
         # by_epoch should be False
         CyclicLrUpdaterHook(by_epoch=True)

     with pytest.raises(AssertionError):
-        # target_ratio" must be either float or tuple/list of two floats
+        # target_ratio must be either float or tuple/list of two floats
         CyclicLrUpdaterHook(by_epoch=False, target_ratio=(10.0, 0.1, 0.2))

     with pytest.raises(AssertionError):
-        # step_ratio_up" must be in range [0,1)
+        # step_ratio_up must be in range [0,1)
         CyclicLrUpdaterHook(by_epoch=False, step_ratio_up=1.4)

     with pytest.raises(ValueError):
         # anneal_strategy must be one of "cos" or "linear"
         CyclicLrUpdaterHook(by_epoch=False, anneal_strategy='sin')

+    with pytest.raises(AssertionError):
+        # gamma must be in range (0, 1]
+        CyclicLrUpdaterHook(by_epoch=False, gamma=0)
+
     sys.modules['pavi'] = MagicMock()
     loader = DataLoader(torch.ones((10, 2)))
     runner = _build_demo_runner(
@@ -1296,13 +1308,14 @@ def test_cyclic_lr_update_hook(multi_optimizers, max_iters):
         multi_optimziers=multi_optimizers)

     # add cyclic LR scheduler
-    hook = CyclicLrUpdaterHook(
+    schedule_hook = CyclicLrUpdaterHook(
         by_epoch=False,
         target_ratio=(10.0, 1.0),
-        cyclic_times=1,
+        cyclic_times=cyclic_times,
         step_ratio_up=0.5,
-        anneal_strategy='linear')
-    runner.register_hook(hook)
+        anneal_strategy='linear',
+        gamma=gamma)
+    runner.register_hook(schedule_hook)
     runner.register_hook_from_cfg(dict(type='IterTimerHook'))
     runner.register_hook(IterTimerHook())

     # add pavi hook
@@ -1343,13 +1356,17 @@ def test_cyclic_lr_update_hook(multi_optimizers, max_iters):
                 'momentum': 0.95
             }, 1),
             call('train', {
-                'learning_rate': 0.155,
+                'learning_rate': 0.11,
                 'momentum': 0.95
             }, 4),
             call('train', {
-                'learning_rate': 0.155,
+                'learning_rate': 0.065,
                 'momentum': 0.95
             }, 6),
+            call('train', {
+                'learning_rate': 0.11,
+                'momentum': 0.95
+            }, 7),
         ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
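The (False, 8, 0.5, 2) expectations in test_cyclic_lr_update_hook can be reproduced by hand. Below is a standalone sketch of the schedule arithmetic, under two assumptions that are not stated in the diff: the demo runner starts from lr=0.02, and the value logged at step n is the LR applied at runner.iter = n - 1:

def annealing_linear(start, end, factor):
    # linear interpolation used by anneal_strategy='linear'
    return start + (end - start) * factor

base_lr = 0.02  # assumed initial LR of the demo runner
gamma, cyclic_times, max_iters = 0.5, 2, 8
target_ratio, step_ratio_up = (10.0, 1.0), 0.5

max_iter_per_phase = max_iters // cyclic_times           # 4
iter_up_phase = int(step_ratio_up * max_iter_per_phase)  # 2
lr_phases = [[0, iter_up_phase, 1, target_ratio[0]],
             [iter_up_phase, max_iter_per_phase,
              target_ratio[0], target_ratio[1]]]

def lr_at(it):
    # mirrors CyclicLrUpdaterHook.get_lr for the config above
    curr_iter = it % max_iter_per_phase
    scale = gamma**(it // max_iter_per_phase)
    for start_iter, end_iter, start_ratio, end_ratio in lr_phases:
        if start_iter <= curr_iter < end_iter:
            if start_iter == 0:
                end_ratio = 1 - scale + end_ratio * scale
            else:
                start_ratio = 1 - scale + start_ratio * scale
            progress = curr_iter - start_iter
            return annealing_linear(base_lr * start_ratio,
                                    base_lr * end_ratio,
                                    progress / (end_iter - start_iter))

print(lr_at(3))  # 0.11  -> expected at step 4
print(lr_at(5))  # 0.065 -> expected at step 6
print(lr_at(6))  # 0.11  -> expected at step 7

The second cycle's peak (0.11 at step 7) equals base_lr + gamma * (0.2 - base_lr): half the original peak-to-base gap, which is exactly the decay the updated test asserts.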