"docs/vscode:/vscode.git/clone" did not exist on "8024ded58f94a6ca0a1dc187ecb3e4a963ebb8fc"
Unverified commit 37121bf4, authored by imyhxy, committed by GitHub

[Fix] Fix a bug in the momentum updater with warmup policy (#1626)

* Fix an attribute-name typo bug in the momentum updater

* Add unit tests for the 'constant', 'linear' and 'exp' momentum warmup policies
Parent commit: 88d802ce
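Note: the warmup policies touched by this fix scale the regular momentum as shown in the hook code below: 'constant' divides it by warmup_ratio, 'linear' divides it by 1 - (1 - cur_iters / warmup_iters) * (1 - warmup_ratio), and 'exp' divides it by warmup_ratio ** (1 - cur_iters / warmup_iters). The snippet below is a minimal standalone sketch of those formulas; it does not import mmcv, and warmup_momentum is a hypothetical helper, not an mmcv function. It assumes the settings used by the new unit tests (base momentum 0.95, warmup_iters=5, warmup_ratio=0.5) and that cur_iters is the 0-based runner.iter, i.e. one less than the step number logged by PaviLoggerHook; under those assumptions it reproduces the expected values in the tests.

# Minimal sketch of the warmup-momentum formulas exercised by the new tests.
# Assumptions: base momentum 0.95, warmup_iters=5, warmup_ratio=0.5, and
# cur_iters being the 0-based runner.iter (one less than the logged step).


def warmup_momentum(policy, cur_iters, regular_momentum,
                    warmup_iters=5, warmup_ratio=0.5):
    if policy == 'constant':
        return regular_momentum / warmup_ratio
    elif policy == 'linear':
        k = (1 - cur_iters / warmup_iters) * (1 - warmup_ratio)
        return regular_momentum / (1 - k)
    elif policy == 'exp':
        k = warmup_ratio**(1 - cur_iters / warmup_iters)
        return regular_momentum / k
    raise ValueError(f'unknown warmup policy: {policy}')


if __name__ == '__main__':
    # Values expected by the added unit tests (logged steps 1 and 3).
    print(warmup_momentum('constant', 0, 0.95))  # 1.9
    print(warmup_momentum('linear', 2, 0.95))    # 1.3571428571428572
    print(warmup_momentum('exp', 2, 0.95))       # 1.4399307381848783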
@@ -74,18 +74,18 @@ class MomentumUpdaterHook(Hook):
             if self.warmup == 'constant':
                 warmup_momentum = [
                     _momentum / self.warmup_ratio
-                    for _momentum in self.regular_momentum
+                    for _momentum in regular_momentum
                 ]
             elif self.warmup == 'linear':
                 k = (1 - cur_iters / self.warmup_iters) * (1 -
                                                            self.warmup_ratio)
                 warmup_momentum = [
-                    _momentum / (1 - k) for _momentum in self.regular_mom
+                    _momentum / (1 - k) for _momentum in regular_momentum
                 ]
             elif self.warmup == 'exp':
                 k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
                 warmup_momentum = [
-                    _momentum / k for _momentum in self.regular_mom
+                    _momentum / k for _momentum in regular_momentum
                 ]
             return warmup_momentum
@@ -128,15 +128,15 @@ class MomentumUpdaterHook(Hook):
     def before_train_epoch(self, runner):
         if not self.by_epoch:
             return
-        self.regular_mom = self.get_regular_momentum(runner)
-        self._set_momentum(runner, self.regular_mom)
+        self.regular_momentum = self.get_regular_momentum(runner)
+        self._set_momentum(runner, self.regular_momentum)

     def before_train_iter(self, runner):
         cur_iter = runner.iter
         if not self.by_epoch:
-            self.regular_mom = self.get_regular_momentum(runner)
+            self.regular_momentum = self.get_regular_momentum(runner)
             if self.warmup is None or cur_iter >= self.warmup_iters:
-                self._set_momentum(runner, self.regular_mom)
+                self._set_momentum(runner, self.regular_momentum)
             else:
                 warmup_momentum = self.get_warmup_momentum(cur_iter)
                 self._set_momentum(runner, warmup_momentum)
@@ -144,7 +144,7 @@ class MomentumUpdaterHook(Hook):
             if self.warmup is None or cur_iter > self.warmup_iters:
                 return
             elif cur_iter == self.warmup_iters:
-                self._set_momentum(runner, self.regular_mom)
+                self._set_momentum(runner, self.regular_momentum)
             else:
                 warmup_momentum = self.get_warmup_momentum(cur_iter)
                 self._set_momentum(runner, warmup_momentum)
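As the diff above shows, the old code stored the regular momentum under self.regular_mom in before_train_epoch/before_train_iter but the 'constant' warmup branch read self.regular_momentum, an attribute that was never set under that name, so using the warmup policy could fail with an AttributeError. The stand-in classes below are hypothetical illustrations of that failure mode, not mmcv classes; they only sketch the mismatched attribute name and the consistent naming after the fix.

# Illustrative stand-ins (not the real mmcv hook) for the naming mismatch
# fixed above: the value is stored under one attribute name but read back
# under another, so the warmup path fails with an AttributeError.


class BuggyUpdater:

    def before_train_epoch(self):
        self.regular_mom = [0.95]  # stored under the old attribute name

    def get_warmup_momentum(self, warmup_ratio=0.5):
        # reads a different attribute name than the one stored above
        return [m / warmup_ratio for m in self.regular_momentum]


class FixedUpdater:

    def before_train_epoch(self):
        self.regular_momentum = [0.95]  # consistent attribute name

    def get_warmup_momentum(self, warmup_ratio=0.5):
        return [m / warmup_ratio for m in self.regular_momentum]


if __name__ == '__main__':
    fixed = FixedUpdater()
    fixed.before_train_epoch()
    print(fixed.get_warmup_momentum())  # [1.9]

    buggy = BuggyUpdater()
    buggy.before_train_epoch()
    try:
        buggy.get_warmup_momentum()
    except AttributeError as err:
        print('old code path fails:', err)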
@@ -137,15 +137,15 @@ def test_checkpoint_hook(tmp_path):
     runner.meta = dict()
     out_dir = 's3://user/data'
     with patch.object(PetrelBackend, 'put') as mock_put, \
-         patch.object(PetrelBackend, 'remove') as mock_remove, \
-         patch.object(PetrelBackend, 'isfile') as mock_isfile:
+            patch.object(PetrelBackend, 'remove') as mock_remove, \
+            patch.object(PetrelBackend, 'isfile') as mock_isfile:
         checkpointhook = CheckpointHook(
             interval=1, out_dir=out_dir, by_epoch=True, max_keep_ckpts=2)
         runner.register_hook(checkpointhook)
         runner.run([loader], [('train', 1)])
         basename = osp.basename(runner.work_dir.rstrip(osp.sep))
         assert runner.meta['hook_msgs']['last_ckpt'] == \
-               '/'.join([out_dir, basename, 'epoch_4.pth'])
+            '/'.join([out_dir, basename, 'epoch_4.pth'])
         mock_put.assert_called()
         mock_remove.assert_called()
         mock_isfile.assert_called()
@@ -168,15 +168,15 @@ def test_checkpoint_hook(tmp_path):
     runner.meta = dict()
     out_dir = 's3://user/data'
     with patch.object(PetrelBackend, 'put') as mock_put, \
-         patch.object(PetrelBackend, 'remove') as mock_remove, \
-         patch.object(PetrelBackend, 'isfile') as mock_isfile:
+            patch.object(PetrelBackend, 'remove') as mock_remove, \
+            patch.object(PetrelBackend, 'isfile') as mock_isfile:
         checkpointhook = CheckpointHook(
             interval=1, out_dir=out_dir, by_epoch=False, max_keep_ckpts=2)
         runner.register_hook(checkpointhook)
         runner.run([loader], [('train', 1)])
         basename = osp.basename(runner.work_dir.rstrip(osp.sep))
         assert runner.meta['hook_msgs']['last_ckpt'] == \
-               '/'.join([out_dir, basename, 'iter_4.pth'])
+            '/'.join([out_dir, basename, 'iter_4.pth'])
         mock_put.assert_called()
         mock_remove.assert_called()
         mock_isfile.assert_called()
@@ -335,7 +335,7 @@ def test_pavi_hook():
             'learning_rate': 0.02,
             'momentum': 0.95
         }, 1)
-    # in windows environment, the latest checkpoint is copied from epoch_1.pth
+    # in Windows environment, the latest checkpoint is copied from epoch_1.pth
     if platform.system() == 'Windows':
         snapshot_file_path = osp.join(runner.work_dir, 'latest.pth')
     else:
@@ -429,6 +429,198 @@ def test_momentum_runner_hook(multi_optimziers):
     ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
+
+    # test constant momentum warmup
+    sys.modules['pavi'] = MagicMock()
+    runner = _build_demo_runner(multi_optimziers=multi_optimziers)
+    # add momentum scheduler
+    hook_cfg = dict(
+        type='StepMomentumUpdaterHook',
+        by_epoch=False,
+        warmup='constant',
+        warmup_iters=5,
+        warmup_ratio=0.5,
+        step=[10],
+    )
+    runner.register_hook_from_cfg(hook_cfg)
+    runner.register_hook_from_cfg(dict(type='IterTimerHook'))
+    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
+    runner.register_hook(hook)
+    runner.run([loader], [('train', 1)])
+    shutil.rmtree(runner.work_dir)
+
+    assert hasattr(hook, 'writer')
+    if multi_optimziers:
+        calls = [
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 1.9,
+                    'momentum/model2': 1.8,
+                }, 1),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 1.9,
+                    'momentum/model2': 1.8,
+                }, 5),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 10),
+        ]
+    else:
+        calls = [
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 1.9
+            }, 1),
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 1.9
+            }, 5),
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 0.95
+            }, 10),
+        ]
+    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
+
+    # test linear momentum warmup
+    sys.modules['pavi'] = MagicMock()
+    runner = _build_demo_runner(multi_optimziers=multi_optimziers)
+    # add momentum scheduler
+    hook_cfg = dict(
+        type='StepMomentumUpdaterHook',
+        by_epoch=False,
+        warmup='linear',
+        warmup_iters=5,
+        warmup_ratio=0.5,
+        step=[10],
+    )
+    runner.register_hook_from_cfg(hook_cfg)
+    runner.register_hook_from_cfg(dict(type='IterTimerHook'))
+    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
+    runner.register_hook(hook)
+    runner.run([loader], [('train', 1)])
+    shutil.rmtree(runner.work_dir)
+
+    assert hasattr(hook, 'writer')
+    if multi_optimziers:
+        calls = [
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 1.9,
+                    'momentum/model2': 1.8,
+                }, 1),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 1.3571428571428572,
+                    'momentum/model2': 1.2857142857142858,
+                }, 3),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 10),
+        ]
+    else:
+        calls = [
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 1.9
+            }, 1),
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 1.3571428571428572
+            }, 3),
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 0.95
+            }, 10),
+        ]
+    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
+
+    # test exponentially momentum warmup
+    sys.modules['pavi'] = MagicMock()
+    runner = _build_demo_runner(multi_optimziers=multi_optimziers)
+    # add momentum scheduler
+    hook_cfg = dict(
+        type='StepMomentumUpdaterHook',
+        by_epoch=False,
+        warmup='exp',
+        warmup_iters=5,
+        warmup_ratio=0.5,
+        step=[10],
+    )
+    runner.register_hook_from_cfg(hook_cfg)
+    runner.register_hook_from_cfg(dict(type='IterTimerHook'))
+    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
+    runner.register_hook(hook)
+    runner.run([loader], [('train', 1)])
+    shutil.rmtree(runner.work_dir)
+
+    assert hasattr(hook, 'writer')
+    if multi_optimziers:
+        calls = [
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 1.9,
+                    'momentum/model2': 1.8,
+                }, 1),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 1.4399307381848783,
+                    'momentum/model2': 1.3641449098593583,
+                }, 3),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 10),
+        ]
+    else:
+        calls = [
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 1.9
+            }, 1),
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 1.4399307381848783
+            }, 3),
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 0.95
+            }, 10),
+        ]
+    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


 @pytest.mark.parametrize('multi_optimziers', (True, False))
 def test_cosine_runner_hook(multi_optimziers):
@@ -438,7 +630,6 @@ def test_cosine_runner_hook(multi_optimziers):
     runner = _build_demo_runner(multi_optimziers=multi_optimziers)
-
     # add momentum scheduler
     hook_cfg = dict(
         type='CosineAnnealingMomentumUpdaterHook',
         min_momentum_ratio=0.99 / 0.95,
@@ -1311,7 +1502,6 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
                        max_epochs=1,
                        max_iters=None,
                        multi_optimziers=False):
-
     log_config = dict(
         interval=1, hooks=[
             dict(type='TextLoggerHook'),
@@ -1326,7 +1516,6 @@
 def test_runner_with_revise_keys():
     import os

     class Model(nn.Module):