Unverified Commit 04346fc7 authored by Ziyi Wu, committed by GitHub

[Feature] Support clipping min_lr in StepLrUpdaterHook to help MMDet3D (#995)

* support clipping min_lr in StepLrUpdaterHook

* add docstring for StepLrUpdaterHook

* fix small bugs

* add unit test for StepLrUpdaterHook

* fix linting error
parent f46e5f8e
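In short, the new `min_lr` option keeps the step policy from decaying the learning rate below a fixed floor. A minimal sketch of the rule, using hypothetical standalone names rather than the hook's actual attributes:

def clipped_step_lr(base_lr, progress, step, gamma=0.1, min_lr=None):
    # decay base_lr by gamma every `step` epochs/iterations
    lr = base_lr * gamma ** (progress // step)
    if min_lr is not None:
        # never let the LR drop below the configured floor
        lr = max(lr, min_lr)
    return lr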
......@@ -32,7 +32,8 @@ def single_gpu_test(model, data_loader):
result = model(return_loss=False, **data)
results.extend(result)
# Assume result has the same length of batch_size
# refer to https://github.com/open-mmlab/mmcv/issues/985
batch_size = len(result)
for _ in range(batch_size):
prog_bar.update()
......@@ -164,8 +164,19 @@ class FixedLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module()
class StepLrUpdaterHook(LrUpdaterHook):
"""Step LR scheduler with min_lr clipping.
def __init__(self, step, gamma=0.1, **kwargs):
Args:
step (int | list[int]): Step to decay the LR. If an int value is given,
regard it as the decay interval. If a list is given, decay LR at
these steps.
gamma (float, optional): Decay LR ratio. Default: 0.1.
min_lr (float, optional): Minimum LR value to keep. If LR after decay
is lower than `min_lr`, it will be clipped to this value. If None
is given, we don't perform lr clipping. Default: None.
"""
def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
assert isinstance(step, (list, int))
if isinstance(step, list):
for s in step:
......@@ -176,20 +187,29 @@ class StepLrUpdaterHook(LrUpdaterHook):
raise TypeError('"step" must be a list or integer')
self.step = step
self.gamma = gamma
self.min_lr = min_lr
super(StepLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
if isinstance(self.step, int):
lr = base_lr * (self.gamma**(progress // self.step))
if self.min_lr is not None:
# clip to a minimum value
lr = max(lr, self.min_lr)
return lr
exp = len(self.step)
for i, s in enumerate(self.step):
if progress < s:
exp = i
break
lr = base_lr * self.gamma**exp
if self.min_lr is not None:
# clip to a minimum value
lr = max(lr, self.min_lr)
return lr
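For reference, the hook can be exercised the same way the unit tests below do. A minimal usage sketch, mirroring the test setup rather than prescribing an API (`runner` is assumed to be an already-built mmcv runner):

from mmcv.runner.hooks.lr_updater import StepLrUpdaterHook

# decay the LR by 2x every 5 iterations, but never let it drop below 1e-3
hook = StepLrUpdaterHook(by_epoch=False, step=5, gamma=0.5, min_lr=1e-3)
runner.register_hook(hook)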
@HOOKS.register_module()
......@@ -22,7 +22,8 @@ from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
build_runner)
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
OneCycleLrUpdaterHook,
StepLrUpdaterHook)
def test_checkpoint_hook():
......@@ -537,6 +538,182 @@ def test_cosine_restart_lr_update_hook(multi_optimziers):
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_step_lr_update_hook(multi_optimziers):
"""Test StepLrUpdaterHook."""
with pytest.raises(TypeError):
# `step` should be specified
StepLrUpdaterHook()
with pytest.raises(AssertionError):
# if `step` is int, should be positive
StepLrUpdaterHook(-10)
with pytest.raises(AssertionError):
# if `step` is list of int, should all be positive
StepLrUpdaterHook([10, 16, -20])
# test StepLrUpdaterHook with int `step` value
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((30, 2)))
runner = _build_demo_runner(multi_optimziers=multi_optimziers)
# add step LR scheduler
hook = StepLrUpdaterHook(by_epoch=False, step=5, gamma=0.5, min_lr=1e-3)
runner.register_hook(hook)
runner.register_hook(IterTimerHook())
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
# TODO: use a more elegant way to check values
assert hasattr(hook, 'writer')
if multi_optimziers:
calls = [
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 1),
call(
'train', {
'learning_rate/model1': 0.01,
'learning_rate/model2': 0.005,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 6),
call(
'train', {
'learning_rate/model1': 0.0025,
'learning_rate/model2': 0.00125,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 16),
call(
'train', {
'learning_rate/model1': 0.00125,
'learning_rate/model2': 0.001,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 21),
call(
'train', {
'learning_rate/model1': 0.001,
'learning_rate/model2': 0.001,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 26),
call(
'train', {
'learning_rate/model1': 0.001,
'learning_rate/model2': 0.001,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 30)
]
else:
calls = [
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 1),
call('train', {
'learning_rate': 0.01,
'momentum': 0.95
}, 6),
call('train', {
'learning_rate': 0.0025,
'momentum': 0.95
}, 16),
call('train', {
'learning_rate': 0.00125,
'momentum': 0.95
}, 21),
call('train', {
'learning_rate': 0.001,
'momentum': 0.95
}, 26),
call('train', {
'learning_rate': 0.001,
'momentum': 0.95
}, 30)
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
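As a sanity check, the single-optimizer expectations above follow directly from the decay rule with base_lr=0.02, step=5, gamma=0.5, min_lr=1e-3, assuming the logged iteration is one ahead of the zero-based progress the hook sees:

base_lr, step, gamma, min_lr = 0.02, 5, 0.5, 1e-3
for it in (1, 6, 16, 21, 26, 30):
    progress = it - 1  # assumed offset between the logged iter and runner.iter
    lr = max(base_lr * gamma ** (progress // step), min_lr)
    print(it, lr)
# prints 0.02, 0.01, 0.0025, 0.00125, 0.001, 0.001 for iters 1, 6, 16, 21, 26, 30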
# test StepLrUpdaterHook with list[int] `step` value
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner(multi_optimziers=multi_optimziers)
# add step LR scheduler
hook = StepLrUpdaterHook(by_epoch=False, step=[4, 6, 8], gamma=0.1)
runner.register_hook(hook)
runner.register_hook(IterTimerHook())
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
# TODO: use a more elegant way to check values
assert hasattr(hook, 'writer')
if multi_optimziers:
calls = [
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 1),
call(
'train', {
'learning_rate/model1': 0.002,
'learning_rate/model2': 0.001,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 5),
call(
'train', {
'learning_rate/model1': 2.0000000000000004e-4,
'learning_rate/model2': 1.0000000000000002e-4,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 7),
call(
'train', {
'learning_rate/model1': 2.0000000000000005e-05,
'learning_rate/model2': 1.0000000000000003e-05,
'momentum/model1': 0.95,
'momentum/model2': 0.9
}, 9)
]
else:
calls = [
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 1),
call('train', {
'learning_rate': 0.002,
'momentum': 0.95
}, 5),
call('train', {
'learning_rate': 2.0000000000000004e-4,
'momentum': 0.95
}, 7),
call('train', {
'learning_rate': 2.0000000000000005e-05,
'momentum': 0.95
}, 9)
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
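The long literals in the list-step expectations (e.g. 2.0000000000000004e-4) are not typos; they are the ordinary floating-point results of base_lr * gamma**exp with gamma=0.1. A quick check using math.isclose rather than exact reprs:

import math

base_lr, gamma = 0.02, 0.1
# LR after the second and third decay steps of the [4, 6, 8] schedule
assert math.isclose(base_lr * gamma ** 2, 2e-4)
assert math.isclose(base_lr * gamma ** 3, 2e-5)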
@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
sys.modules['mlflow'] = MagicMock()