Unverified Commit 1a5bf762 authored by gengenkai, committed by GitHub

[Enhancement] LR triangular (#1001)

* [Fix] lr add triangular

* [Fix] lr_triangular

* lr add docstring

* lr add docstring

* add unittest of triangular lr update
parent 47962467
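The change below lets CyclicLrUpdaterHook choose between the two annealing helpers already defined in lr_updater.py, annealing_cos and annealing_linear, through a new anneal_strategy argument. For readers skimming the diff, here is a minimal sketch of what such helpers compute; the exact signatures in mmcv may differ slightly (the weight argument in particular is an assumption), so treat it as illustrative only.

    from math import cos, pi

    def annealing_cos(start, end, factor, weight=1):
        # Cosine-anneal from `start` to `end` as `factor` goes 0.0 -> 1.0.
        cos_out = cos(pi * factor) + 1
        return end + 0.5 * weight * (start - end) * cos_out

    def annealing_linear(start, end, factor):
        # Linearly interpolate from `start` to `end` as `factor` goes 0.0 -> 1.0.
        return start + (end - start) * factor

With factor = 0 both helpers return start and with factor = 1 both return end; they differ only in the shape of the transition, which is exactly what the new anneal_strategy flag selects.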
@@ -367,13 +367,16 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
     triangular policy inside a cycle. This improves the performance in the
     3D detection area.

-    Attributes:
+    Args:
+        by_epoch (bool): Whether to update LR by epoch.
         target_ratio (tuple[float]): Relative ratio of the highest LR and the
             lowest LR to the initial LR.
         cyclic_times (int): Number of cycles during training
         step_ratio_up (float): The ratio of the increasing process of LR in
             the total cycle.
-        by_epoch (bool): Whether to update LR by epoch.
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing. Default: 'cos'.
     """
     def __init__(self,
@@ -381,6 +384,7 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
                  target_ratio=(10, 1e-4),
                  cyclic_times=1,
                  step_ratio_up=0.4,
+                 anneal_strategy='cos',
                  **kwargs):
         if isinstance(target_ratio, float):
             target_ratio = (target_ratio, target_ratio / 1e5)
@@ -400,6 +404,14 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
         self.cyclic_times = cyclic_times
         self.step_ratio_up = step_ratio_up
         self.lr_phases = []  # init lr_phases
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
         assert not by_epoch, \
             'currently only support "by_epoch" = False'
@@ -425,9 +437,9 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
             curr_iter %= max_iter_per_phase
             if start_iter <= curr_iter < end_iter:
                 progress = curr_iter - start_iter
-                return annealing_cos(base_lr * start_ratio,
-                                     base_lr * end_ratio,
-                                     progress / (end_iter - start_iter))
+                return self.anneal_func(base_lr * start_ratio,
+                                        base_lr * end_ratio,
+                                        progress / (end_iter - start_iter))


 @HOOKS.register_module()
...
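Downstream usage only changes by one keyword argument. Below is a minimal sketch of constructing the hook with the new option and attaching it to a runner, mirroring the unit test added further down; the runner variable is assumed to be an existing IterBasedRunner, and the config-dict style lr_config = dict(policy='cyclic', ...) used in downstream repos should forward the argument through **kwargs, though that path is not exercised in this diff.

    from mmcv.runner.hooks.lr_updater import CyclicLrUpdaterHook

    # Cycle the LR once over training: rise for the first 40% of the cycle,
    # then anneal back down linearly instead of with the default cosine curve.
    lr_hook = CyclicLrUpdaterHook(
        by_epoch=False,            # the hook currently only supports iter-based updates
        target_ratio=(10, 1e-4),   # (highest, lowest) LR relative to the initial LR
        cyclic_times=1,
        step_ratio_up=0.4,
        anneal_strategy='linear')  # new argument introduced by this commit
    runner.register_hook(lr_hook)  # `runner` assumed to be built elsewhere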
@@ -22,6 +22,7 @@ from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
                          MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
                          build_runner)
 from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
+                                           CyclicLrUpdaterHook,
                                            OneCycleLrUpdaterHook,
                                            StepLrUpdaterHook)
@@ -714,6 +715,93 @@ def test_step_lr_update_hook(multi_optimziers):
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
+@pytest.mark.parametrize('multi_optimizers, max_iters', [(True, 8),
+                                                         (False, 8)])
+def test_cyclic_lr_update_hook(multi_optimizers, max_iters):
+    """Test CyclicLrUpdateHook."""
+    with pytest.raises(AssertionError):
+        # by_epoch should be False
+        CyclicLrUpdaterHook(by_epoch=True)
+
+    with pytest.raises(AssertionError):
+        # "target_ratio" must be either float or tuple/list of two floats
+        CyclicLrUpdaterHook(by_epoch=False, target_ratio=(10.0, 0.1, 0.2))
+
+    with pytest.raises(AssertionError):
+        # "step_ratio_up" must be in range [0,1)
+        CyclicLrUpdaterHook(by_epoch=False, step_ratio_up=1.4)
+
+    with pytest.raises(ValueError):
+        # anneal_strategy must be one of "cos" or "linear"
+        CyclicLrUpdaterHook(by_epoch=False, anneal_strategy='sin')
+
+    sys.modules['pavi'] = MagicMock()
+    loader = DataLoader(torch.ones((10, 2)))
+    runner = _build_demo_runner(
+        runner_type='IterBasedRunner',
+        max_epochs=None,
+        max_iters=max_iters,
+        multi_optimziers=multi_optimizers)
+
+    # add cyclic LR scheduler
+    hook = CyclicLrUpdaterHook(
+        by_epoch=False,
+        target_ratio=(10.0, 1.0),
+        cyclic_times=1,
+        step_ratio_up=0.5,
+        anneal_strategy='linear')
+    runner.register_hook(hook)
+    runner.register_hook_from_cfg(dict(type='IterTimerHook'))
+    runner.register_hook(IterTimerHook())
+
+    # add pavi hook
+    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
+    runner.register_hook(hook)
+    runner.run([loader], [('train', 1)])
+    shutil.rmtree(runner.work_dir)
+
+    assert hasattr(hook, 'writer')
+    if multi_optimizers:
+        calls = [
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 1),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.155,
+                    'learning_rate/model2': 0.0775,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 4),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.155,
+                    'learning_rate/model2': 0.0775,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 6)
+        ]
+    else:
+        calls = [
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 0.95
+            }, 1),
+            call('train', {
+                'learning_rate': 0.155,
+                'momentum': 0.95
+            }, 4),
+            call('train', {
+                'learning_rate': 0.155,
+                'momentum': 0.95
+            }, 6),
+        ]
+    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
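The magic numbers in calls fall straight out of the linear schedule. With the demo runner's base LRs of 0.02 and 0.01, target_ratio=(10.0, 1.0), step_ratio_up=0.5 and max_iters=8, the up phase covers iterations 0-3 (LR climbing from 0.02 toward 0.2) and the down phase covers iterations 4-7 (back toward 0.02). Assuming the logger records the LR of the iteration just completed, step 4 corresponds to 3/4 of the up phase and step 6 to 1/4 of the down phase, both of which land on 0.155. A short sketch of that arithmetic:

    def annealing_linear(start, end, factor):
        return start + (end - start) * factor

    base_lr, max_iters, up_end = 0.02, 8, 4  # up_end = step_ratio_up * max_iters

    # step 4: iteration 3, i.e. progress 3/4 through the up phase (0.02 -> 0.2)
    print(annealing_linear(base_lr * 1.0, base_lr * 10.0, 3 / up_end))    # ~0.155
    # step 6: iteration 5, i.e. progress 1/4 through the down phase (0.2 -> 0.02)
    print(annealing_linear(base_lr * 10.0, base_lr * 1.0,
                           (5 - up_end) / (max_iters - up_end)))          # ~0.155

The second optimizer's base LR of 0.01 scales everything by half, giving the 0.0775 values checked above.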
 @pytest.mark.parametrize('log_model', (True, False))
 def test_mlflow_hook(log_model):
     sys.modules['mlflow'] = MagicMock()
...