Unverified Commit 37121bf4 authored by imyhxy's avatar imyhxy Committed by GitHub
Browse files

[Fix] Fix the bug of momentum updater with warmup policy (#1626)

* Fix a typo bug in the momentum updater

* Add unit test for 'constant', 'linear' and 'exp' momentum warmup policy
parent 88d802ce
...@@ -74,18 +74,18 @@ class MomentumUpdaterHook(Hook): ...@@ -74,18 +74,18 @@ class MomentumUpdaterHook(Hook):
if self.warmup == 'constant': if self.warmup == 'constant':
warmup_momentum = [ warmup_momentum = [
_momentum / self.warmup_ratio _momentum / self.warmup_ratio
for _momentum in self.regular_momentum for _momentum in regular_momentum
] ]
elif self.warmup == 'linear': elif self.warmup == 'linear':
k = (1 - cur_iters / self.warmup_iters) * (1 - k = (1 - cur_iters / self.warmup_iters) * (1 -
self.warmup_ratio) self.warmup_ratio)
warmup_momentum = [ warmup_momentum = [
_momentum / (1 - k) for _momentum in self.regular_mom _momentum / (1 - k) for _momentum in regular_momentum
] ]
elif self.warmup == 'exp': elif self.warmup == 'exp':
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
warmup_momentum = [ warmup_momentum = [
_momentum / k for _momentum in self.regular_mom _momentum / k for _momentum in regular_momentum
] ]
return warmup_momentum return warmup_momentum
...@@ -128,15 +128,15 @@ class MomentumUpdaterHook(Hook): ...@@ -128,15 +128,15 @@ class MomentumUpdaterHook(Hook):
def before_train_epoch(self, runner):
    """Compute and apply the regular (scheduled) momentum for this epoch.

    Only active for epoch-based updaters (``self.by_epoch`` is True);
    iteration-based updaters recompute momentum in ``before_train_iter``
    instead.

    Args:
        runner: The training runner whose optimizer momentum is updated.
    """
    if not self.by_epoch:
        return
    # Cache the schedule value; the warmup logic elsewhere reads
    # ``self.regular_momentum`` once warmup finishes.
    self.regular_momentum = self.get_regular_momentum(runner)
    self._set_momentum(runner, self.regular_momentum)
def before_train_iter(self, runner):
    """Set the optimizer momentum before every training iteration.

    While ``runner.iter`` is inside the warmup window the warmup momentum
    is applied; once warmup is over (or disabled) the regular scheduled
    momentum is applied.

    Args:
        runner: The training runner; ``runner.iter`` gives the current
            global iteration.
    """
    cur_iter = runner.iter
    if not self.by_epoch:
        # Iteration-based schedule: recompute the regular momentum
        # every step so the post-warmup value tracks the schedule.
        self.regular_momentum = self.get_regular_momentum(runner)
    if self.warmup is None or cur_iter >= self.warmup_iters:
        self._set_momentum(runner, self.regular_momentum)
    else:
        warmup_momentum = self.get_warmup_momentum(cur_iter)
        self._set_momentum(runner, warmup_momentum)
...@@ -144,7 +144,7 @@ class MomentumUpdaterHook(Hook): ...@@ -144,7 +144,7 @@ class MomentumUpdaterHook(Hook):
if self.warmup is None or cur_iter > self.warmup_iters: if self.warmup is None or cur_iter > self.warmup_iters:
return return
elif cur_iter == self.warmup_iters: elif cur_iter == self.warmup_iters:
self._set_momentum(runner, self.regular_mom) self._set_momentum(runner, self.regular_momentum)
else: else:
warmup_momentum = self.get_warmup_momentum(cur_iter) warmup_momentum = self.get_warmup_momentum(cur_iter)
self._set_momentum(runner, warmup_momentum) self._set_momentum(runner, warmup_momentum)
......
...@@ -335,7 +335,7 @@ def test_pavi_hook(): ...@@ -335,7 +335,7 @@ def test_pavi_hook():
'learning_rate': 0.02, 'learning_rate': 0.02,
'momentum': 0.95 'momentum': 0.95
}, 1) }, 1)
# in windows environment, the latest checkpoint is copied from epoch_1.pth # in Windows environment, the latest checkpoint is copied from epoch_1.pth
if platform.system() == 'Windows': if platform.system() == 'Windows':
snapshot_file_path = osp.join(runner.work_dir, 'latest.pth') snapshot_file_path = osp.join(runner.work_dir, 'latest.pth')
else: else:
...@@ -429,6 +429,198 @@ def test_momentum_runner_hook(multi_optimziers): ...@@ -429,6 +429,198 @@ def test_momentum_runner_hook(multi_optimziers):
] ]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True) hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
# test constant momentum warmup
sys.modules['pavi'] = MagicMock()
runner = _build_demo_runner(multi_optimziers=multi_optimziers)
# add momentum scheduler
hook_cfg = dict(
type='StepMomentumUpdaterHook',
by_epoch=False,
warmup='constant',
warmup_iters=5,
warmup_ratio=0.5,
step=[10],
)
runner.register_hook_from_cfg(hook_cfg)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
assert hasattr(hook, 'writer')
if multi_optimziers:
calls = [
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 1.9,
'momentum/model2': 1.8,
}, 1),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 1.9,
'momentum/model2': 1.8,
}, 5),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9,
}, 10),
]
else:
calls = [
call('train', {
'learning_rate': 0.02,
'momentum': 1.9
}, 1),
call('train', {
'learning_rate': 0.02,
'momentum': 1.9
}, 5),
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 10),
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
# test linear momentum warmup
sys.modules['pavi'] = MagicMock()
runner = _build_demo_runner(multi_optimziers=multi_optimziers)
# add momentum scheduler
hook_cfg = dict(
type='StepMomentumUpdaterHook',
by_epoch=False,
warmup='linear',
warmup_iters=5,
warmup_ratio=0.5,
step=[10],
)
runner.register_hook_from_cfg(hook_cfg)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
assert hasattr(hook, 'writer')
if multi_optimziers:
calls = [
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 1.9,
'momentum/model2': 1.8,
}, 1),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 1.3571428571428572,
'momentum/model2': 1.2857142857142858,
}, 3),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9,
}, 10),
]
else:
calls = [
call('train', {
'learning_rate': 0.02,
'momentum': 1.9
}, 1),
call('train', {
'learning_rate': 0.02,
'momentum': 1.3571428571428572
}, 3),
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 10),
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
# test exponential momentum warmup
sys.modules['pavi'] = MagicMock()
runner = _build_demo_runner(multi_optimziers=multi_optimziers)
# add momentum scheduler
hook_cfg = dict(
type='StepMomentumUpdaterHook',
by_epoch=False,
warmup='exp',
warmup_iters=5,
warmup_ratio=0.5,
step=[10],
)
runner.register_hook_from_cfg(hook_cfg)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)
assert hasattr(hook, 'writer')
if multi_optimziers:
calls = [
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 1.9,
'momentum/model2': 1.8,
}, 1),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 1.4399307381848783,
'momentum/model2': 1.3641449098593583,
}, 3),
call(
'train', {
'learning_rate/model1': 0.02,
'learning_rate/model2': 0.01,
'momentum/model1': 0.95,
'momentum/model2': 0.9,
}, 10),
]
else:
calls = [
call('train', {
'learning_rate': 0.02,
'momentum': 1.9
}, 1),
call('train', {
'learning_rate': 0.02,
'momentum': 1.4399307381848783
}, 3),
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 10),
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@pytest.mark.parametrize('multi_optimziers', (True, False)) @pytest.mark.parametrize('multi_optimziers', (True, False))
def test_cosine_runner_hook(multi_optimziers): def test_cosine_runner_hook(multi_optimziers):
...@@ -438,7 +630,6 @@ def test_cosine_runner_hook(multi_optimziers): ...@@ -438,7 +630,6 @@ def test_cosine_runner_hook(multi_optimziers):
runner = _build_demo_runner(multi_optimziers=multi_optimziers) runner = _build_demo_runner(multi_optimziers=multi_optimziers)
# add momentum scheduler # add momentum scheduler
hook_cfg = dict( hook_cfg = dict(
type='CosineAnnealingMomentumUpdaterHook', type='CosineAnnealingMomentumUpdaterHook',
min_momentum_ratio=0.99 / 0.95, min_momentum_ratio=0.99 / 0.95,
...@@ -1311,7 +1502,6 @@ def _build_demo_runner(runner_type='EpochBasedRunner', ...@@ -1311,7 +1502,6 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1, max_epochs=1,
max_iters=None, max_iters=None,
multi_optimziers=False): multi_optimziers=False):
log_config = dict( log_config = dict(
interval=1, hooks=[ interval=1, hooks=[
dict(type='TextLoggerHook'), dict(type='TextLoggerHook'),
...@@ -1326,7 +1516,6 @@ def _build_demo_runner(runner_type='EpochBasedRunner', ...@@ -1326,7 +1516,6 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
def test_runner_with_revise_keys(): def test_runner_with_revise_keys():
import os import os
class Model(nn.Module): class Model(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment