Unverified Commit 89efc607 authored by Miao Zheng, committed by GitHub

[Fix] OneCycleLrUpdaterHook adds total_steps (#929)

* [Fix] OneCycleLrUpdaterHook interface

* revise according to comments

* revise according to comments

* add test

* fix lint

* revise according to comments

* minors

* add pytest param

* fix lint

* ci
parent f959a34d
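
With this change, the cycle length can be set explicitly instead of always being tied to the runner's `max_iters`. A minimal usage sketch (hypothetical values, assuming mmcv's standard `lr_config` registration, where `policy='OneCycle'` resolves to this hook):

```python
# Hypothetical config: total_steps sets the cycle length explicitly.
# If omitted (None), the hook falls back to the runner's max_iters.
lr_config = dict(
    policy='OneCycle',
    max_lr=0.01,
    total_steps=1000,  # must be >= the runner's max_iters
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25,
    final_div_factor=1e4)
```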
@@ -421,6 +421,9 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
Args:
max_lr (float or list): Upper learning rate boundaries in the cycle
for each parameter group.
        total_steps (int, optional): The total number of steps in the cycle.
            If a value is not provided here, it defaults to the max_iters
            of the runner. Default: None.
pct_start (float): The percentage of the cycle (in number of steps)
spent increasing the learning rate.
Default: 0.3
@@ -444,6 +447,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
def __init__(self,
max_lr,
total_steps=None,
pct_start=0.3,
anneal_strategy='cos',
div_factor=25,
@@ -460,6 +464,11 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
            raise ValueError('the type of max_lr must be list or '
                             f'dict, but got {type(max_lr)}')
self._max_lr = max_lr
if total_steps is not None:
if not isinstance(total_steps, int):
                raise ValueError('the type of total_steps must be int, but '
                                 f'got {type(total_steps)}')
self.total_steps = total_steps
# validate pct_start
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError('expected float between 0 and 1 for pct_start, but '
@@ -480,6 +489,16 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
def before_run(self, runner):
if hasattr(self, 'total_steps'):
total_steps = self.total_steps
else:
total_steps = runner.max_iters
if total_steps < runner.max_iters:
raise ValueError(
'The total steps must be greater than or equal to max '
f'iterations {runner.max_iters} of runner, but total steps '
f'is {total_steps}.')
if isinstance(runner.optimizer, dict):
self.base_lr = {}
for k, optim in runner.optimizer.items():
@@ -495,25 +514,18 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
group.setdefault('initial_lr', lr)
        if self.three_phase:
            self.lr_phases.append(
                [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
            self.lr_phases.append([
                float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1
            ])
            self.lr_phases.append(
                [total_steps - 1, 1, 1 / self.final_div_factor])
        else:
            self.lr_phases.append(
                [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
            self.lr_phases.append(
                [total_steps - 1, self.div_factor, 1 / self.final_div_factor])
def get_lr(self, runner, base_lr):
curr_iter = runner.iter
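To make the phase arithmetic above concrete, here is a rough standalone sketch (illustrative numbers, not part of the patch) of the two-phase boundaries that `before_run` builds; the factors multiply the group's initial lr, which is `max_lr / div_factor` in this hook:

```python
# Illustrative phase boundaries for the default two-phase schedule.
total_steps = 100
pct_start = 0.3
div_factor = 25
final_div_factor = 1e4

# Each entry is [end_iter, start_factor, end_factor]; factors scale the
# initial lr (max_lr / div_factor).
lr_phases = [
    # warm-up: initial lr -> max_lr over the first 30% of the cycle
    [float(pct_start * total_steps) - 1, 1, div_factor],
    # anneal: max_lr -> max_lr / (div_factor * final_div_factor)
    [total_steps - 1, div_factor, 1 / final_div_factor],
]
print(lr_phases)  # [[29.0, 1, 25], [99, 25, 0.0001]]
```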
@@ -305,8 +305,10 @@ def test_cosine_runner_hook(multi_optimziers):
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@pytest.mark.parametrize('multi_optimziers, max_iters', [(True, 10), (True, 2),
                                                         (False, 10),
                                                         (False, 2)])
def test_one_cycle_runner_hook(multi_optimziers, max_iters):
"""Test OneCycleLrUpdaterHook and OneCycleMomentumUpdaterHook."""
with pytest.raises(AssertionError):
# by_epoch should be False
@@ -396,6 +398,40 @@ def test_one_cycle_runner_hook(multi_optimziers):
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
# Test OneCycleLrUpdaterHook
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner(
runner_type='IterBasedRunner', max_epochs=None, max_iters=max_iters)
args = dict(
max_lr=0.01,
total_steps=5,
pct_start=0.5,
anneal_strategy='linear',
div_factor=25,
final_div_factor=1e4,
)
hook = OneCycleLrUpdaterHook(**args)
runner.register_hook(hook)
    if max_iters == 10:
        # test total_steps < max_iters
        with pytest.raises(ValueError):
            runner.run([loader], [('train', 1)])
    else:
        # test total_steps > max_iters
        runner.run([loader], [('train', 1)])
        lr_last = runner.current_lr()
        t = torch.tensor([0.0], requires_grad=True)
        optim = torch.optim.SGD([t], lr=0.01)
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, **args)
        lr_target = []
        for _ in range(max_iters):
            optim.step()
            lr_target.append(optim.param_groups[0]['lr'])
            lr_scheduler.step()
        assert lr_target[-1] == lr_last[0]
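
For reference, the comparison target above can also be run standalone; this sketch drives PyTorch's native `OneCycleLR` with the same hyperparameters as `args` and traces the full five-step cycle:

```python
import torch

# Reference trajectory: torch's native OneCycleLR with the test's settings.
t = torch.tensor([0.0], requires_grad=True)
optim = torch.optim.SGD([t], lr=0.01)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optim,
    max_lr=0.01,
    total_steps=5,
    pct_start=0.5,
    anneal_strategy='linear',
    div_factor=25,
    final_div_factor=1e4)

lrs = [optim.param_groups[0]['lr']]  # starts at max_lr / div_factor = 4e-4
for _ in range(4):
    optim.step()
    scheduler.step()
    lrs.append(optim.param_groups[0]['lr'])
# warms up to max_lr in the first half of the cycle, then anneals linearly
# down to (max_lr / div_factor) / final_div_factor = 4e-8
print(lrs)
```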
@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_cosine_restart_lr_update_hook(multi_optimziers):