Unverified commit f28a7c7e authored by Harry, committed by GitHub

Add CosineRestartLrUpdaterHook (#319)

* feat: add CosineRestartLrUpdaterHook

* style: rename period to periods

* fix: bug in period 0

* feat: rename eta_min to min_lr and add min_lr_ratio

* docs: fix docstring of restart lr updater

* refactor: use annealing_cos

* docs: add docstring to annealing_cos

* feat: cosine restart lr update hook

* refactor: modify code order for unittest
parent 9f04477f
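
For context, here is a minimal usage sketch of the new hook (not part of the diff). It mirrors the unit test added further below and assumes a `runner` (e.g. an `EpochBasedRunner`) has already been built elsewhere:

```python
from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook

# Two cosine cycles of 5 iterations each; every restart begins at half the
# base lr (restart weight 0.5) and anneals towards min_lr within its cycle.
lr_hook = CosineRestartLrUpdaterHook(
    by_epoch=False,              # interpret periods in iterations, not epochs
    periods=[5, 5],
    restart_weights=[0.5, 0.5],
    min_lr=0.0001)
runner.register_hook(lr_hook)    # `runner` is assumed to exist already
```

In downstream configs the same hook would normally be selected through `lr_config` with a matching policy name, but that mapping lives outside this diff.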
@@ -247,6 +247,82 @@ class CosineAnealingLrUpdaterHook(LrUpdaterHook):
        return annealing_cos(base_lr, target_lr, progress / max_progress)
@HOOKS.register_module()
class CosineRestartLrUpdaterHook(LrUpdaterHook):
    """Cosine annealing with restarts learning rate scheme.

    Args:
        periods (list[int]): Periods for each cosine annealing cycle.
        restart_weights (list[float], optional): Restart weights at each
            restart iteration. Default: [1].
        min_lr (float, optional): The minimum lr. Default: None.
        min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
            Either `min_lr` or `min_lr_ratio` should be specified.
            Default: None.
    """

    def __init__(self,
                 periods,
                 restart_weights=[1],
                 min_lr=None,
                 min_lr_ratio=None,
                 **kwargs):
        assert (min_lr is None) ^ (min_lr_ratio is None)
        self.periods = periods
        self.min_lr = min_lr
        self.min_lr_ratio = min_lr_ratio
        self.restart_weights = restart_weights
        assert (len(self.periods) == len(self.restart_weights)
                ), 'periods and restart_weights should have the same length.'
        super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)

        self.cumulative_periods = [
            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
        ]

    def get_lr(self, runner, base_lr):
        if self.by_epoch:
            progress = runner.epoch
        else:
            progress = runner.iter

        if self.min_lr_ratio is not None:
            target_lr = base_lr * self.min_lr_ratio
        else:
            target_lr = self.min_lr

        idx = get_position_from_periods(progress, self.cumulative_periods)
        current_weight = self.restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
        current_periods = self.periods[idx]

        alpha = min((progress - nearest_restart) / current_periods, 1)
        return annealing_cos(base_lr, target_lr, alpha, current_weight)
def get_position_from_periods(iteration, cumulative_periods):
    """Get the position from a period list.

    It will return the index of the right-closest number in the period list.
    For example, the cumulative_periods = [100, 200, 300, 400],
    if iteration == 50, return 0;
    if iteration == 210, return 2;
    if iteration == 300, return 2.

    Args:
        iteration (int): Current iteration.
        cumulative_periods (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list.
    """
    for i, period in enumerate(cumulative_periods):
        if iteration <= period:
            return i

    raise ValueError(f'Current iteration {iteration} exceeds '
                     f'cumulative_periods {cumulative_periods}')
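
As a quick standalone check (not part of the diff), the boundary behaviour documented in the docstring can be exercised directly, assuming the helper above is importable:

```python
cumulative_periods = [100, 200, 300, 400]
assert get_position_from_periods(50, cumulative_periods) == 0
assert get_position_from_periods(210, cumulative_periods) == 2
# iteration == 300 maps to the cycle that ends at 300 (index 2), not the next
assert get_position_from_periods(300, cumulative_periods) == 2
```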
@HOOKS.register_module()
class CyclicLrUpdaterHook(LrUpdaterHook):
    """Cyclic LR Scheduler

@@ -322,7 +398,19 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
                             progress / (end_iter - start_iter))
def annealing_cos(start, end, factor, weight=1):
    """Calculate annealing cos learning rate.

    Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
    percentage goes from 0.0 to 1.0.

    Args:
        start (float): The starting learning rate of the cosine annealing.
        end (float): The ending learning rate of the cosine annealing.
        factor (float): The coefficient of `pi` when calculating the current
            percentage. Range from 0.0 to 1.0.
        weight (float, optional): The combination factor of `start` and `end`
            when calculating the actual starting learning rate. Default to 1.
    """
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out
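
To illustrate what the `weight` argument does across a restart, here is a short standalone sketch (not part of the diff) that reproduces the same arithmetic outside the runner; the settings are illustrative, not taken from the tests below:

```python
from math import cos, pi

def annealing_cos(start, end, factor, weight=1):
    # identical formula to the one above
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out

base_lr, min_lr = 0.01, 0.0
periods, restart_weights = [5, 5], [0.5, 0.5]
cumulative_periods = [5, 10]

for it in (0, 3, 5, 6, 9):
    idx = next(i for i, p in enumerate(cumulative_periods) if it <= p)
    nearest_restart = 0 if idx == 0 else cumulative_periods[idx - 1]
    alpha = min((it - nearest_restart) / periods[idx], 1)
    print(it, annealing_cos(base_lr, min_lr, alpha, restart_weights[idx]))
# 0 -> 0.005      (weight 0.5 halves the effective starting lr)
# 3 -> ~0.00173
# 5 -> 0.0        (end of the first cycle)
# 6 -> ~0.00452   (restart: the lr jumps back up, then anneals again)
# 9 -> ~0.00048
```

The restart weight thus controls how high the learning rate rebounds at each restart relative to the base lr.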
@@ -21,6 +21,7 @@ from torch.utils.data import DataLoader
from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
                         PaviLoggerHook, WandbLoggerHook)
from mmcv.runner.hooks.lr_updater import (CosineAnealingLrUpdaterHook,
                                          CosineRestartLrUpdaterHook,
                                          CyclicLrUpdaterHook)
from mmcv.runner.hooks.momentum_updater import (
    CosineAnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook)

@@ -144,6 +145,84 @@ def test_cosine_runner_hook():
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
def test_cosine_restart_lr_update_hook():
    """Test CosineRestartLrUpdaterHook."""
    with pytest.raises(AssertionError):
        # either `min_lr` or `min_lr_ratio` should be specified
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5, 0.5],
            min_lr=0.1,
            min_lr_ratio=0)

    with pytest.raises(AssertionError):
        # periods and restart_weights should have the same length
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5],
            min_lr_ratio=0)

    with pytest.raises(ValueError):
        # the last cumulative period (7, from [5, 7]) should be >= the 10
        # training iterations; it is not, so get_lr raises during training
        sys.modules['pavi'] = MagicMock()
        loader = DataLoader(torch.ones((10, 2)))
        runner = _build_demo_runner()

        # add cosine restart LR scheduler
        hook = CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[5, 2],  # cumulative_periods [5, 7 (5 + 2)]
            restart_weights=[0.5, 0.5],
            min_lr=0.0001)
        runner.register_hook(hook)
        runner.register_hook(IterTimerHook())

        # add pavi hook
        hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
        runner.register_hook(hook)
        runner.run([loader], [('train', 1)], 1)
        shutil.rmtree(runner.work_dir)

    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cosine restart LR scheduler
    hook = CosineRestartLrUpdaterHook(
        by_epoch=False,
        periods=[5, 5],
        restart_weights=[0.5, 0.5],
        min_lr_ratio=0)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)], 1)
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.95
        }, 0),
        call('train', {
            'learning_rate': 0.0,
            'momentum': 0.95
        }, 5),
        call('train', {
            'learning_rate': 0.0009549150281252633,
            'momentum': 0.95
        }, 9)
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
    sys.modules['mlflow'] = MagicMock()
...