Unverified commit c77e95a6, authored by Ziyi Wu, committed by GitHub

[Enhance] Add step momentum updater to support MMDet3D (#1011)

* add StepMomentumUpdaterHook

* add unit test

* fix typos

* refactor step updater

* replace stage with exp

* fix linting error

* use all() operation
parent 1a5bf762
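
A minimal sketch of how the new hook can be attached to a runner, mirroring the pattern the unit test below uses. The config values here are illustrative, not part of the commit, and `runner` is assumed to be an already-built mmcv runner:

    # Illustrative values only; `runner` is assumed to exist already.
    momentum_cfg = dict(
        type='StepMomentumUpdaterHook',
        by_epoch=False,     # schedule by iteration instead of by epoch
        step=5,             # decay every 5 iterations
        gamma=0.5,          # halve the momentum at each decay
        min_momentum=0.05)  # clip so momentum never drops below 0.05
    runner.register_hook_from_cfg(momentum_cfg)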
@@ -2,6 +2,7 @@
 import numbers
 from math import cos, pi
+import mmcv
 from .hook import HOOKS, Hook
@@ -177,10 +178,9 @@ class StepLrUpdaterHook(LrUpdaterHook):
     """
 
     def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
-        assert isinstance(step, (list, int))
         if isinstance(step, list):
-            for s in step:
-                assert isinstance(s, int) and s > 0
+            assert mmcv.is_list_of(step, int)
+            assert all([s > 0 for s in step])
         elif isinstance(step, int):
             assert step > 0
         else:
@@ -193,19 +193,17 @@ class StepLrUpdaterHook(LrUpdaterHook):
     def get_lr(self, runner, base_lr):
         progress = runner.epoch if self.by_epoch else runner.iter
 
+        # calculate exponential term
         if isinstance(self.step, int):
-            lr = base_lr * (self.gamma**(progress // self.step))
-            if self.min_lr is not None:
-                # clip to a minimum value
-                lr = max(lr, self.min_lr)
-            return lr
-
-        exp = len(self.step)
-        for i, s in enumerate(self.step):
-            if progress < s:
-                exp = i
-                break
-        lr = base_lr * self.gamma**exp
+            exp = progress // self.step
+        else:
+            exp = len(self.step)
+            for i, s in enumerate(self.step):
+                if progress < s:
+                    exp = i
+                    break
+
+        lr = base_lr * (self.gamma**exp)
         if self.min_lr is not None:
             # clip to a minimum value
             lr = max(lr, self.min_lr)
...
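
To make the refactor above concrete, here is a small standalone sketch of the shared step logic; `step_exp` is a hypothetical helper name for illustration, not part of the diff:

    def step_exp(progress, step):
        """Number of decays that have happened by `progress`.

        `step` is either an int (fixed decay interval) or a sorted list
        of decay milestones, matching the semantics in the diff above.
        """
        if isinstance(step, int):
            return progress // step
        # list case: count how many milestones have already been passed
        exp = len(step)
        for i, s in enumerate(step):
            if progress < s:
                exp = i
                break
        return exp

    assert step_exp(16, 5) == 3         # int interval: 16 // 5
    assert step_exp(7, [4, 6, 8]) == 2  # milestones 4 and 6 already passed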
+import mmcv
 from .hook import HOOKS, Hook
 from .lr_updater import annealing_cos, annealing_linear, format_param
@@ -148,6 +149,54 @@ class MomentumUpdaterHook(Hook):
             self._set_momentum(runner, warmup_momentum)
 
 
+@HOOKS.register_module()
+class StepMomentumUpdaterHook(MomentumUpdaterHook):
+    """Step momentum scheduler with min value clipping.
+
+    Args:
+        step (int | list[int]): Step to decay the momentum. If an int value is
+            given, regard it as the decay interval. If a list is given, decay
+            momentum at these steps.
+        gamma (float, optional): Decay momentum ratio. Default: 0.5.
+        min_momentum (float, optional): Minimum momentum value to keep. If
+            momentum after decay is lower than this value, it will be clipped
+            accordingly. If None is given, we don't perform momentum clipping.
+            Default: None.
+    """
+
+    def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
+        if isinstance(step, list):
+            assert mmcv.is_list_of(step, int)
+            assert all([s > 0 for s in step])
+        elif isinstance(step, int):
+            assert step > 0
+        else:
+            raise TypeError('"step" must be a list or integer')
+        self.step = step
+        self.gamma = gamma
+        self.min_momentum = min_momentum
+        super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        progress = runner.epoch if self.by_epoch else runner.iter
+
+        # calculate exponential term
+        if isinstance(self.step, int):
+            exp = progress // self.step
+        else:
+            exp = len(self.step)
+            for i, s in enumerate(self.step):
+                if progress < s:
+                    exp = i
+                    break
+
+        momentum = base_momentum * (self.gamma**exp)
+        if self.min_momentum is not None:
+            # clip to a minimum value
+            momentum = max(momentum, self.min_momentum)
+        return momentum
+
+
 @HOOKS.register_module()
 class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
@@ -419,11 +468,12 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
             end_iter = phase['end_iter']
             if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
                 pct = (curr_iter - start_iter) / (end_iter - start_iter)
-                lr = self.anneal_func(param_group[phase['start_momentum']],
-                                      param_group[phase['end_momentum']], pct)
+                momentum = self.anneal_func(
+                    param_group[phase['start_momentum']],
+                    param_group[phase['end_momentum']], pct)
                 break
             start_iter = end_iter
-        return lr
+        return momentum
 
     def get_regular_momentum(self, runner):
         if isinstance(runner.optimizer, dict):
...
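
As a sanity check on the expected values in the test diff below, the int-step schedule registered there (base momentum 0.95, step=5, gamma=0.5, min_momentum=0.05) works out as follows; `momentum_at` is an illustrative helper that mirrors the arithmetic of `get_momentum`:

    base, gamma, step, min_m = 0.95, 0.5, 5, 0.05

    def momentum_at(it):
        # same arithmetic as StepMomentumUpdaterHook.get_momentum
        return max(base * gamma**(it // step), min_m)

    # iteration -> momentum logged by the test below
    assert momentum_at(6) == 0.475      # one decay
    assert momentum_at(16) == 0.11875   # three decays
    assert momentum_at(21) == 0.059375  # four decays
    assert momentum_at(26) == 0.05      # clipped at min_momentum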
@@ -540,7 +540,7 @@ def test_cosine_restart_lr_update_hook(multi_optimziers):
 
 @pytest.mark.parametrize('multi_optimziers', (True, False))
-def test_step_lr_update_hook(multi_optimziers):
+def test_step_runner_hook(multi_optimziers):
     """Test StepLrUpdaterHook."""
     with pytest.raises(TypeError):
         # `step` should be specified
@@ -557,6 +557,15 @@
     loader = DataLoader(torch.ones((30, 2)))
     runner = _build_demo_runner(multi_optimziers=multi_optimziers)
 
+    # add momentum scheduler
+    hook_cfg = dict(
+        type='StepMomentumUpdaterHook',
+        by_epoch=False,
+        step=5,
+        gamma=0.5,
+        min_momentum=0.05)
+    runner.register_hook_from_cfg(hook_cfg)
+
     # add step LR scheduler
     hook = StepLrUpdaterHook(by_epoch=False, step=5, gamma=0.5, min_lr=1e-3)
     runner.register_hook(hook)
@@ -583,36 +592,36 @@
                 'train', {
                     'learning_rate/model1': 0.01,
                     'learning_rate/model2': 0.005,
-                    'momentum/model1': 0.95,
-                    'momentum/model2': 0.9
+                    'momentum/model1': 0.475,
+                    'momentum/model2': 0.45
                 }, 6),
             call(
                 'train', {
                     'learning_rate/model1': 0.0025,
                     'learning_rate/model2': 0.00125,
-                    'momentum/model1': 0.95,
-                    'momentum/model2': 0.9
+                    'momentum/model1': 0.11875,
+                    'momentum/model2': 0.1125
                 }, 16),
             call(
                 'train', {
                     'learning_rate/model1': 0.00125,
                     'learning_rate/model2': 0.001,
-                    'momentum/model1': 0.95,
-                    'momentum/model2': 0.9
+                    'momentum/model1': 0.059375,
+                    'momentum/model2': 0.05625
                 }, 21),
             call(
                 'train', {
                     'learning_rate/model1': 0.001,
                     'learning_rate/model2': 0.001,
-                    'momentum/model1': 0.95,
-                    'momentum/model2': 0.9
+                    'momentum/model1': 0.05,
+                    'momentum/model2': 0.05
                 }, 26),
             call(
                 'train', {
                     'learning_rate/model1': 0.001,
                     'learning_rate/model2': 0.001,
-                    'momentum/model1': 0.95,
-                    'momentum/model2': 0.9
+                    'momentum/model1': 0.05,
+                    'momentum/model2': 0.05
                 }, 30)
         ]
     else:
@@ -623,23 +632,23 @@
             }, 1),
             call('train', {
                 'learning_rate': 0.01,
-                'momentum': 0.95
+                'momentum': 0.475
             }, 6),
             call('train', {
                 'learning_rate': 0.0025,
-                'momentum': 0.95
+                'momentum': 0.11875
             }, 16),
             call('train', {
                 'learning_rate': 0.00125,
-                'momentum': 0.95
+                'momentum': 0.059375
             }, 21),
             call('train', {
                 'learning_rate': 0.001,
-                'momentum': 0.95
+                'momentum': 0.05
             }, 26),
             call('train', {
                 'learning_rate': 0.001,
-                'momentum': 0.95
+                'momentum': 0.05
             }, 30)
         ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@@ -649,6 +658,14 @@
     loader = DataLoader(torch.ones((10, 2)))
     runner = _build_demo_runner(multi_optimziers=multi_optimziers)
 
+    # add momentum scheduler
+    hook_cfg = dict(
+        type='StepMomentumUpdaterHook',
+        by_epoch=False,
+        step=[4, 6, 8],
+        gamma=0.1)
+    runner.register_hook_from_cfg(hook_cfg)
+
     # add step LR scheduler
     hook = StepLrUpdaterHook(by_epoch=False, step=[4, 6, 8], gamma=0.1)
     runner.register_hook(hook)
@@ -675,22 +692,22 @@
                 'train', {
                     'learning_rate/model1': 0.002,
                     'learning_rate/model2': 0.001,
-                    'momentum/model1': 0.95,
-                    'momentum/model2': 0.9
+                    'momentum/model1': 9.5e-2,
+                    'momentum/model2': 9.000000000000001e-2
                 }, 5),
             call(
                 'train', {
                     'learning_rate/model1': 2.0000000000000004e-4,
                     'learning_rate/model2': 1.0000000000000002e-4,
-                    'momentum/model1': 0.95,
-                    'momentum/model2': 0.9
+                    'momentum/model1': 9.500000000000001e-3,
+                    'momentum/model2': 9.000000000000003e-3
                 }, 7),
             call(
                 'train', {
                     'learning_rate/model1': 2.0000000000000005e-05,
                     'learning_rate/model2': 1.0000000000000003e-05,
-                    'momentum/model1': 0.95,
-                    'momentum/model2': 0.9
+                    'momentum/model1': 9.500000000000002e-4,
+                    'momentum/model2': 9.000000000000002e-4
                 }, 9)
         ]
     else:
@@ -701,16 +718,18 @@
             }, 1),
             call('train', {
                 'learning_rate': 0.002,
-                'momentum': 0.95
+                'momentum': 0.095
             }, 5),
-            call('train', {
-                'learning_rate': 2.0000000000000004e-4,
-                'momentum': 0.95
-            }, 7),
-            call('train', {
-                'learning_rate': 2.0000000000000005e-05,
-                'momentum': 0.95
-            }, 9)
+            call(
+                'train', {
+                    'learning_rate': 2.0000000000000004e-4,
+                    'momentum': 9.500000000000001e-3
+                }, 7),
+            call(
+                'train', {
+                    'learning_rate': 2.0000000000000005e-05,
+                    'momentum': 9.500000000000002e-4
+                }, 9)
         ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
...
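
The list-step expectations above (step=[4, 6, 8], gamma=0.1, base momentum 0.95, no clipping) can be checked the same way; an illustrative sketch, using pytest.approx since the logged literals carry float rounding noise:

    import pytest

    base, gamma, milestones = 0.95, 0.1, [4, 6, 8]

    def momentum_at(it):
        # count milestones already passed; equivalent to the hook's loop
        # for a sorted milestone list
        exp = sum(it >= s for s in milestones)
        return base * gamma**exp

    assert momentum_at(5) == pytest.approx(0.095)    # past milestone 4
    assert momentum_at(7) == pytest.approx(0.0095)   # past 4 and 6
    assert momentum_at(9) == pytest.approx(0.00095)  # past all three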