Unverified commit 025ee9f5 authored by Gu Wang, committed by GitHub

[Feature] Add flat cosine lr updater (#1066)

* add flat cosine lr updater

* add test

* add doc

* update doc

* reformat

* update unittest

* update test flat cos

* remove momentum hook test

* update test

* change assert to ValueError

* fix unittest

* add by_epoch=True unittest

* change to start_percent

* change to start_percent in test
parent f860d0ed
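
For orientation, here is a minimal usage sketch (not part of this commit). It assumes mmcv's usual convention that the `policy` name in `lr_config` is resolved by appending `LrUpdaterHook`, so `policy='FlatCosineAnnealing'` maps to the hook added below; the concrete numbers are illustrative only.

# Hypothetical training-config snippet: keep the lr flat at the base value
# for the first 75% of training, then cosine-anneal it down towards 0.
lr_config = dict(
    policy='FlatCosineAnnealing',
    start_percent=0.75,
    min_lr_ratio=0.0,
    by_epoch=False)
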
mmcv/runner/hooks/lr_updater.py
@@ -278,6 +278,60 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
        return annealing_cos(base_lr, target_lr, progress / max_progress)


@HOOKS.register_module()
class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
    """Flat + Cosine lr schedule.

    Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501

    Args:
        start_percent (float): The percentage of total training steps after
            which the learning rate starts to anneal; before that point it
            stays flat at the base lr. The value should be in range [0, 1).
            Default: 0.75.
        min_lr (float, optional): The minimum lr. Default: None.
        min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
            Exactly one of `min_lr` and `min_lr_ratio` should be specified.
            Default: None.
    """

    def __init__(self,
                 start_percent=0.75,
                 min_lr=None,
                 min_lr_ratio=None,
                 **kwargs):
        assert (min_lr is None) ^ (min_lr_ratio is None)
        if start_percent < 0 or start_percent > 1 or not isinstance(
                start_percent, float):
            raise ValueError(
                'expected float between 0 and 1 start_percent, but '
                f'got {start_percent}')
        self.start_percent = start_percent
        self.min_lr = min_lr
        self.min_lr_ratio = min_lr_ratio
        super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        if self.by_epoch:
            start = round(runner.max_epochs * self.start_percent)
            progress = runner.epoch - start
            max_progress = runner.max_epochs - start
        else:
            start = round(runner.max_iters * self.start_percent)
            progress = runner.iter - start
            max_progress = runner.max_iters - start

        if self.min_lr_ratio is not None:
            target_lr = base_lr * self.min_lr_ratio
        else:
            target_lr = self.min_lr

        if progress < 0:
            return base_lr
        else:
            return annealing_cos(base_lr, target_lr, progress / max_progress)


@HOOKS.register_module()
class CosineRestartLrUpdaterHook(LrUpdaterHook):
    """Cosine annealing with restarts learning rate scheme.
......
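
To make the schedule shape concrete, a short sketch follows. It is an editor's illustration, not part of the commit, and it assumes `annealing_cos(start, end, factor)` computes `end + (start - end) * (cos(pi * factor) + 1) / 2`, i.e. standard cosine interpolation; before `start` the hook returns `base_lr` unchanged, which produces the flat phase.

import math

def flat_cosine_lr(base_lr, target_lr, step, max_steps, start_percent=0.75):
    # Flat phase: keep the base lr until `start_percent` of training is done.
    start = round(max_steps * start_percent)
    if step < start:
        return base_lr
    # Cosine phase: interpolate from base_lr down to target_lr.
    factor = (step - start) / (max_steps - start)
    return target_lr + (base_lr - target_lr) * (math.cos(math.pi * factor) + 1) / 2

# With base_lr=0.02, target_lr=0, max_steps=10 and start_percent=0.5 this yields
# 0.02 for steps 0-4 (flat), ~0.01809 at step 6 and ~0.00191 at step 9, matching
# the expected values in the test below.
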
tests/test_hooks.py
@@ -25,6 +25,7 @@ from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
                                          CyclicLrUpdaterHook,
                                          FlatCosineAnnealingLrUpdaterHook,
                                          OneCycleLrUpdaterHook,
                                          StepLrUpdaterHook)
@@ -375,6 +376,147 @@ def test_cosine_runner_hook(multi_optimziers):
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimziers, by_epoch', [(False, False),
                                                        (True, False),
                                                        (False, True),
                                                        (True, True)])
def test_flat_cosine_runner_hook(multi_optimziers, by_epoch):
    """xdoctest -m tests/test_hooks.py test_flat_cosine_runner_hook."""
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    max_epochs = 10 if by_epoch else 1
    runner = _build_demo_runner(
        multi_optimziers=multi_optimziers, max_epochs=max_epochs)

    with pytest.raises(ValueError):
        # start_percent: expected float between 0 and 1
        FlatCosineAnnealingLrUpdaterHook(start_percent=-0.1, min_lr_ratio=0)

    # add LR scheduler
    hook_cfg = dict(
        type='FlatCosineAnnealingLrUpdaterHook',
        by_epoch=by_epoch,
        min_lr_ratio=0,
        warmup='linear',
        warmup_iters=10 if by_epoch else 2,
        warmup_ratio=0.9,
        start_percent=0.5)
    runner.register_hook_from_cfg(hook_cfg)
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))
    runner.register_hook(IterTimerHook())
    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimziers:
        if by_epoch:
            calls = [
                call(
                    'train', {
                        'learning_rate/model1': 0.018000000000000002,
                        'learning_rate/model2': 0.009000000000000001,
                        'momentum/model1': 0.95,
                        'momentum/model2': 0.9,
                    }, 1),
                call(
                    'train', {
                        'learning_rate/model1': 0.02,
                        'learning_rate/model2': 0.01,
                        'momentum/model1': 0.95,
                        'momentum/model2': 0.9,
                    }, 11),
                call(
                    'train', {
                        'learning_rate/model1': 0.018090169943749474,
                        'learning_rate/model2': 0.009045084971874737,
                        'momentum/model1': 0.95,
                        'momentum/model2': 0.9,
                    }, 61),
                call(
                    'train', {
                        'learning_rate/model1': 0.0019098300562505265,
                        'learning_rate/model2': 0.0009549150281252633,
                        'momentum/model1': 0.95,
                        'momentum/model2': 0.9,
                    }, 100)
            ]
        else:
            calls = [
                call(
                    'train', {
                        'learning_rate/model1': 0.018000000000000002,
                        'learning_rate/model2': 0.009000000000000001,
                        'momentum/model1': 0.95,
                        'momentum/model2': 0.9
                    }, 1),
                call(
                    'train', {
                        'learning_rate/model1': 0.02,
                        'learning_rate/model2': 0.01,
                        'momentum/model1': 0.95,
                        'momentum/model2': 0.9
                    }, 6),
                call(
                    'train', {
                        'learning_rate/model1': 0.018090169943749474,
                        'learning_rate/model2': 0.009045084971874737,
                        'momentum/model1': 0.95,
                        'momentum/model2': 0.9
                    }, 7),
                call(
                    'train', {
                        'learning_rate/model1': 0.0019098300562505265,
                        'learning_rate/model2': 0.0009549150281252633,
                        'momentum/model1': 0.95,
                        'momentum/model2': 0.9
                    }, 10)
            ]
    else:
        if by_epoch:
            calls = [
                call('train', {
                    'learning_rate': 0.018000000000000002,
                    'momentum': 0.95
                }, 1),
                call('train', {
                    'learning_rate': 0.02,
                    'momentum': 0.95
                }, 11),
                call('train', {
                    'learning_rate': 0.018090169943749474,
                    'momentum': 0.95
                }, 61),
                call('train', {
                    'learning_rate': 0.0019098300562505265,
                    'momentum': 0.95
                }, 100)
            ]
        else:
            calls = [
                call('train', {
                    'learning_rate': 0.018000000000000002,
                    'momentum': 0.95
                }, 1),
                call('train', {
                    'learning_rate': 0.02,
                    'momentum': 0.95
                }, 6),
                call('train', {
                    'learning_rate': 0.018090169943749474,
                    'momentum': 0.95
                }, 7),
                call('train', {
                    'learning_rate': 0.0019098300562505265,
                    'momentum': 0.95
                }, 10)
            ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
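
A quick check of where the expected numbers above come from (editor's arithmetic, not part of the test): the first logged lr is the linear-warmup value base_lr * warmup_ratio = 0.02 * 0.9; the lr then stays flat at 0.02 until start_percent=0.5 of training; in the by_epoch=True case, iteration 61 falls in epoch 6, giving a cosine factor of (6 - 5) / (10 - 5) = 0.2, and the last logged point corresponds to a factor of 0.8.

import math

def cosine(factor, base_lr=0.02):
    # target lr is 0 because the test sets min_lr_ratio=0
    return base_lr * (math.cos(math.pi * factor) + 1) / 2

assert math.isclose(0.02 * 0.9, 0.018000000000000002)    # linear warmup value
assert math.isclose(cosine(0.2), 0.018090169943749474)   # epoch 6 of 10, anneal starts at 5
assert math.isclose(cosine(0.8), 0.0019098300562505265)  # epoch 9 of 10
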


@pytest.mark.parametrize('multi_optimziers, max_iters', [(True, 10), (True, 2),
                                                         (False, 10),
                                                         (False, 2)])
......