Unverified Commit d525cfde authored by zhouzaida, committed by GitHub

[Fix] Fix bug of lr updater hook (#907)

* [Fix] fix bug of lr update hook

* [Fix] fix bug of lr update hook

* [Fix] Fix bug of lr updater hook
parent 03a2e3a1
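
For context (a sketch, not part of the diff): when a runner is built with a dict of optimizers, get_regular_lr returns a dict of per-optimizer LR lists, but the old get_warmup_lr iterated self.regular_lr as if it were a flat list, so LR warmup broke in the multi-optimizer case. A minimal illustration with invented values:

    # Illustrative only; the values are invented, not taken from the diff.
    warmup_ratio = 0.1
    regular_lr_single = [0.02]                               # one optimizer
    regular_lr_multi = {'model1': [0.02], 'model2': [0.01]}  # dict of optimizers

    # The pre-fix 'constant' branch assumed a flat list of learning rates:
    print([lr * warmup_ratio for lr in regular_lr_single])   # [0.002] -- fine

    # Iterating a dict yields its keys (strings), so the same expression fails:
    # [lr * warmup_ratio for lr in regular_lr_multi]  # TypeError: str * float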
@@ -82,15 +82,26 @@ class LrUpdaterHook(Hook):
         return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]

     def get_warmup_lr(self, cur_iters):
-        if self.warmup == 'constant':
-            warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
-        elif self.warmup == 'linear':
-            k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
-            warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr]
-        elif self.warmup == 'exp':
-            k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
-            warmup_lr = [_lr * k for _lr in self.regular_lr]
-        return warmup_lr
+
+        def _get_warmup_lr(cur_iters, regular_lr):
+            if self.warmup == 'constant':
+                warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (1 -
+                                                           self.warmup_ratio)
+                warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_lr = [_lr * k for _lr in regular_lr]
+            return warmup_lr
+
+        if isinstance(self.regular_lr, dict):
+            lr_groups = {}
+            for key, regular_lr in self.regular_lr.items():
+                lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
+            return lr_groups
+        else:
+            return _get_warmup_lr(cur_iters, self.regular_lr)

     def before_run(self, runner):
         # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
...
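
The refactor above isolates the warmup arithmetic in a nested helper and then dispatches on the type of self.regular_lr. A standalone sketch of that dispatch pattern (simplified to 'constant' warmup, with hypothetical names; not the hook code itself):

    # Simplified sketch of the dispatch pattern used above (hypothetical helper).
    def get_warmup_lr(regular_lr, warmup_ratio=0.1):

        def _warmup(lrs):
            # 'constant' warmup: scale every regular LR by the warmup ratio
            return [lr * warmup_ratio for lr in lrs]

        if isinstance(regular_lr, dict):
            # one warmup list per optimizer name
            return {name: _warmup(lrs) for name, lrs in regular_lr.items()}
        return _warmup(regular_lr)

    print(get_warmup_lr([0.02]))                                # [0.002]
    print(get_warmup_lr({'model1': [0.02], 'model2': [0.01]}))  # per-optimizer dict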
@@ -31,51 +31,97 @@ class MomentumUpdaterHook(Hook):
         ]  # expected momentum if no warming up is performed

     def _set_momentum(self, runner, momentum_groups):
-        for param_group, mom in zip(runner.optimizer.param_groups,
-                                    momentum_groups):
-            if 'momentum' in param_group.keys():
-                param_group['momentum'] = mom
-            elif 'betas' in param_group.keys():
-                param_group['betas'] = (mom, param_group['betas'][1])
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, mom in zip(optim.param_groups,
+                                            momentum_groups[k]):
+                    if 'momentum' in param_group.keys():
+                        param_group['momentum'] = mom
+                    elif 'betas' in param_group.keys():
+                        param_group['betas'] = (mom, param_group['betas'][1])
+        else:
+            for param_group, mom in zip(runner.optimizer.param_groups,
+                                        momentum_groups):
+                if 'momentum' in param_group.keys():
+                    param_group['momentum'] = mom
+                elif 'betas' in param_group.keys():
+                    param_group['betas'] = (mom, param_group['betas'][1])

     def get_momentum(self, runner, base_momentum):
         raise NotImplementedError

     def get_regular_momentum(self, runner):
-        return [
-            self.get_momentum(runner, _base_momentum)
-            for _base_momentum in self.base_momentum
-        ]
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k in runner.optimizer.keys():
+                _momentum_group = [
+                    self.get_momentum(runner, _base_momentum)
+                    for _base_momentum in self.base_momentum[k]
+                ]
+                momentum_groups.update({k: _momentum_group})
+            return momentum_groups
+        else:
+            return [
+                self.get_momentum(runner, _base_momentum)
+                for _base_momentum in self.base_momentum
+            ]

     def get_warmup_momentum(self, cur_iters):
-        if self.warmup == 'constant':
-            warmup_momentum = [
-                _momentum / self.warmup_ratio
-                for _momentum in self.regular_momentum
-            ]
-        elif self.warmup == 'linear':
-            k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
-            warmup_momentum = [
-                _momentum / (1 - k) for _momentum in self.regular_mom
-            ]
-        elif self.warmup == 'exp':
-            k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
-            warmup_momentum = [_momentum / k for _momentum in self.regular_mom]
-        return warmup_momentum
+
+        def _get_warmup_momentum(cur_iters, regular_momentum):
+            if self.warmup == 'constant':
+                warmup_momentum = [
+                    _momentum / self.warmup_ratio
+                    for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (1 -
+                                                           self.warmup_ratio)
+                warmup_momentum = [
+                    _momentum / (1 - k) for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_momentum = [
+                    _momentum / k for _momentum in regular_momentum
+                ]
+            return warmup_momentum
+
+        if isinstance(self.regular_momentum, dict):
+            momentum_groups = {}
+            for key, regular_momentum in self.regular_momentum.items():
+                momentum_groups[key] = _get_warmup_momentum(
+                    cur_iters, regular_momentum)
+            return momentum_groups
+        else:
+            return _get_warmup_momentum(cur_iters, self.regular_momentum)

     def before_run(self, runner):
         # NOTE: when resuming from a checkpoint,
         # if 'initial_momentum' is not saved,
         # it will be set according to the optimizer params
-        for group in runner.optimizer.param_groups:
-            if 'momentum' in group.keys():
-                group.setdefault('initial_momentum', group['momentum'])
-            else:
-                group.setdefault('initial_momentum', group['betas'][0])
-        self.base_momentum = [
-            group['initial_momentum']
-            for group in runner.optimizer.param_groups
-        ]
+        if isinstance(runner.optimizer, dict):
+            self.base_momentum = {}
+            for k, optim in runner.optimizer.items():
+                for group in optim.param_groups:
+                    if 'momentum' in group.keys():
+                        group.setdefault('initial_momentum', group['momentum'])
+                    else:
+                        group.setdefault('initial_momentum', group['betas'][0])
+                _base_momentum = [
+                    group['initial_momentum'] for group in optim.param_groups
+                ]
+                self.base_momentum.update({k: _base_momentum})
+        else:
+            for group in runner.optimizer.param_groups:
+                if 'momentum' in group.keys():
+                    group.setdefault('initial_momentum', group['momentum'])
+                else:
+                    group.setdefault('initial_momentum', group['betas'][0])
+            self.base_momentum = [
+                group['initial_momentum']
+                for group in runner.optimizer.param_groups
+            ]

     def before_train_epoch(self, runner):
         if not self.by_epoch:
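
To see what the per-optimizer branch of _set_momentum is doing, here is a self-contained sketch in plain PyTorch (the two-part model, optimizer names, and momentum values are invented, mirroring the dict layout the hook expects):

    import torch
    import torch.nn as nn

    # Invented demo model and optimizers, keyed the same way as momentum_groups.
    model = nn.ModuleDict({'linear': nn.Linear(2, 1), 'conv': nn.Conv2d(3, 3, 3)})
    optimizer = {
        'model1': torch.optim.SGD(model['linear'].parameters(), lr=0.02, momentum=0.95),
        'model2': torch.optim.SGD(model['conv'].parameters(), lr=0.01, momentum=0.9),
    }

    # One momentum value per param group, grouped under the optimizer's key.
    momentum_groups = {'model1': [0.5], 'model2': [0.4]}

    for k, optim in optimizer.items():
        for param_group, mom in zip(optim.param_groups, momentum_groups[k]):
            if 'momentum' in param_group:        # SGD-style optimizers
                param_group['momentum'] = mom
            elif 'betas' in param_group:         # Adam-style optimizers
                param_group['betas'] = (mom, param_group['betas'][1])

    print(optimizer['model1'].param_groups[0]['momentum'])  # 0.5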
@@ -383,9 +429,11 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
         if isinstance(runner.optimizer, dict):
             momentum_groups = {}
             for k, optim in runner.optimizer.items():
-                for param_group in optim.param_groups:
-                    momentum_groups[k].append(
-                        self.get_momentum(runner, param_group))
+                _momentum_group = [
+                    self.get_momentum(runner, param_group)
+                    for param_group in optim.param_groups
+                ]
+                momentum_groups.update({k: _momentum_group})
             return momentum_groups
         else:
             momentum_groups = []
...
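
The change above replaces per-group appends into momentum_groups[k], which would raise a KeyError because the key is never initialized, with building the whole list first and then attaching it under the optimizer's key. A minimal reproduction of the difference (plain Python, invented values):

    # Illustrative only: appending to a missing dict key fails; the fixed pattern works.
    per_optimizer_values = {'model1': [0.9, 0.8]}  # invented momentum values

    broken = {}
    # broken['model1'].append(0.9)               # KeyError: 'model1' (old pattern)

    fixed = {}
    for k, values in per_optimizer_values.items():
        _group = [v for v in values]             # build the whole list first ...
        fixed.update({k: _group})                # ... then attach it under the key
    print(fixed)                                 # {'model1': [0.9, 0.8]}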
@@ -151,11 +151,12 @@ def test_sync_buffers_hook():
     shutil.rmtree(runner.work_dir)


-def test_momentum_runner_hook():
+@pytest.mark.parametrize('multi_optimziers', (True, False))
+def test_momentum_runner_hook(multi_optimziers):
     """xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
     sys.modules['pavi'] = MagicMock()
     loader = DataLoader(torch.ones((10, 2)))
-    runner = _build_demo_runner()
+    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

     # add momentum scheduler
     hook_cfg = dict(
@@ -184,28 +185,54 @@ def test_momentum_runner_hook():
     # TODO: use a more elegant way to check values
     assert hasattr(hook, 'writer')
-    calls = [
-        call('train', {
-            'learning_rate': 0.01999999999999999,
-            'momentum': 0.95
-        }, 1),
-        call('train', {
-            'learning_rate': 0.2,
-            'momentum': 0.85
-        }, 5),
-        call('train', {
-            'learning_rate': 0.155,
-            'momentum': 0.875
-        }, 7),
-    ]
+    if multi_optimziers:
+        calls = [
+            call(
+                'train', {
+                    'learning_rate/model1': 0.01999999999999999,
+                    'learning_rate/model2': 0.009999999999999995,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 1),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.2,
+                    'learning_rate/model2': 0.1,
+                    'momentum/model1': 0.85,
+                    'momentum/model2': 0.8052631578947369,
+                }, 5),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.155,
+                    'learning_rate/model2': 0.0775,
+                    'momentum/model1': 0.875,
+                    'momentum/model2': 0.8289473684210527,
+                }, 7)
+        ]
+    else:
+        calls = [
+            call('train', {
+                'learning_rate': 0.01999999999999999,
+                'momentum': 0.95
+            }, 1),
+            call('train', {
+                'learning_rate': 0.2,
+                'momentum': 0.85
+            }, 5),
+            call('train', {
+                'learning_rate': 0.155,
+                'momentum': 0.875
+            }, 7),
+        ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


-def test_cosine_runner_hook():
+@pytest.mark.parametrize('multi_optimziers', (True, False))
+def test_cosine_runner_hook(multi_optimziers):
     """xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
     sys.modules['pavi'] = MagicMock()
     loader = DataLoader(torch.ones((10, 2)))
-    runner = _build_demo_runner()
+    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

     # add momentum scheduler
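
In the multi-optimizer branch, the expected scalar names are suffixed with the optimizer's dict key (learning_rate/model1, momentum/model2, and so on). A sketch of how such a flat scalar dict could be assembled from per-optimizer values (illustrative; not the logger hook's actual code, and the numbers are invented):

    # Illustrative only: building 'metric/<optimizer name>' keys per optimizer.
    current_lr = {'model1': [0.02], 'model2': [0.01]}
    current_momentum = {'model1': [0.95], 'model2': [0.9]}

    scalars = {}
    for name, lrs in current_lr.items():
        scalars[f'learning_rate/{name}'] = lrs[0]     # first param group only
    for name, moms in current_momentum.items():
        scalars[f'momentum/{name}'] = moms[0]
    print(scalars)
    # {'learning_rate/model1': 0.02, 'learning_rate/model2': 0.01,
    #  'momentum/model1': 0.95, 'momentum/model2': 0.9}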
@@ -235,24 +262,51 @@ def test_cosine_runner_hook():
     # TODO: use a more elegant way to check values
     assert hasattr(hook, 'writer')
-    calls = [
-        call('train', {
-            'learning_rate': 0.02,
-            'momentum': 0.95
-        }, 1),
-        call('train', {
-            'learning_rate': 0.01,
-            'momentum': 0.97
-        }, 6),
-        call('train', {
-            'learning_rate': 0.0004894348370484647,
-            'momentum': 0.9890211303259032
-        }, 10)
-    ]
+    if multi_optimziers:
+        calls = [
+            call(
+                'train', {
+                    'learning_rate/model1': 0.02,
+                    'learning_rate/model2': 0.01,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 1),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.01,
+                    'learning_rate/model2': 0.005,
+                    'momentum/model1': 0.97,
+                    'momentum/model2': 0.9189473684210527,
+                }, 6),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.0004894348370484647,
+                    'learning_rate/model2': 0.00024471741852423234,
+                    'momentum/model1': 0.9890211303259032,
+                    'momentum/model2': 0.9369673866245399,
+                }, 10)
+        ]
+    else:
+        calls = [
+            call('train', {
+                'learning_rate': 0.02,
+                'momentum': 0.95
+            }, 1),
+            call('train', {
+                'learning_rate': 0.01,
+                'momentum': 0.97
+            }, 6),
+            call(
+                'train', {
+                    'learning_rate': 0.0004894348370484647,
+                    'momentum': 0.9890211303259032
+                }, 10)
+        ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


-def test_one_cycle_runner_hook():
+@pytest.mark.parametrize('multi_optimziers', (True, False))
+def test_one_cycle_runner_hook(multi_optimziers):
     """Test OneCycleLrUpdaterHook and OneCycleMomentumUpdaterHook."""
     with pytest.raises(AssertionError):
         # by_epoch should be False
@@ -268,7 +322,7 @@ def test_one_cycle_runner_hook():
     sys.modules['pavi'] = MagicMock()
     loader = DataLoader(torch.ones((10, 2)))
-    runner = _build_demo_runner()
+    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

     # add momentum scheduler
     hook_cfg = dict(
@@ -280,7 +334,7 @@ def test_one_cycle_runner_hook():
         three_phase=False)
     runner.register_hook_from_cfg(hook_cfg)

-    # add momentum LR scheduler
+    # add LR scheduler
     hook_cfg = dict(
         type='OneCycleLrUpdaterHook',
         max_lr=0.01,
@@ -300,24 +354,51 @@ def test_one_cycle_runner_hook():
     # TODO: use a more elegant way to check values
     assert hasattr(hook, 'writer')
-    calls = [
-        call('train', {
-            'learning_rate': 0.0003999999999999993,
-            'momentum': 0.95
-        }, 1),
-        call('train', {
-            'learning_rate': 0.00904508879153485,
-            'momentum': 0.8595491502812526
-        }, 6),
-        call('train', {
-            'learning_rate': 4e-08,
-            'momentum': 0.95
-        }, 10)
-    ]
+    if multi_optimziers:
+        calls = [
+            call(
+                'train', {
+                    'learning_rate/model1': 0.0003999999999999993,
+                    'learning_rate/model2': 0.0003999999999999993,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.95,
+                }, 1),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.00904508879153485,
+                    'learning_rate/model2': 0.00904508879153485,
+                    'momentum/model1': 0.8595491502812526,
+                    'momentum/model2': 0.8595491502812526,
+                }, 6),
+            call(
+                'train', {
+                    'learning_rate/model1': 4e-08,
+                    'learning_rate/model2': 4e-08,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.95,
+                }, 10)
+        ]
+    else:
+        calls = [
+            call('train', {
+                'learning_rate': 0.0003999999999999993,
+                'momentum': 0.95
+            }, 1),
+            call(
+                'train', {
+                    'learning_rate': 0.00904508879153485,
+                    'momentum': 0.8595491502812526
+                }, 6),
+            call('train', {
+                'learning_rate': 4e-08,
+                'momentum': 0.95
+            }, 10)
+        ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


-def test_cosine_restart_lr_update_hook():
+@pytest.mark.parametrize('multi_optimziers', (True, False))
+def test_cosine_restart_lr_update_hook(multi_optimziers):
     """Test CosineRestartLrUpdaterHook."""
     with pytest.raises(AssertionError):
         # either `min_lr` or `min_lr_ratio` should be specified
@@ -359,7 +440,7 @@ def test_cosine_restart_lr_update_hook():
     sys.modules['pavi'] = MagicMock()
     loader = DataLoader(torch.ones((10, 2)))
-    runner = _build_demo_runner()
+    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

     # add cosine restart LR scheduler
     hook = CosineRestartLrUpdaterHook(
@@ -378,20 +459,45 @@ def test_cosine_restart_lr_update_hook():
     # TODO: use a more elegant way to check values
     assert hasattr(hook, 'writer')
-    calls = [
-        call('train', {
-            'learning_rate': 0.01,
-            'momentum': 0.95
-        }, 1),
-        call('train', {
-            'learning_rate': 0.01,
-            'momentum': 0.95
-        }, 6),
-        call('train', {
-            'learning_rate': 0.0009549150281252633,
-            'momentum': 0.95
-        }, 10)
-    ]
+    if multi_optimziers:
+        calls = [
+            call(
+                'train', {
+                    'learning_rate/model1': 0.01,
+                    'learning_rate/model2': 0.005,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 1),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.01,
+                    'learning_rate/model2': 0.005,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 6),
+            call(
+                'train', {
+                    'learning_rate/model1': 0.0009549150281252633,
+                    'learning_rate/model2': 0.00047745751406263163,
+                    'momentum/model1': 0.95,
+                    'momentum/model2': 0.9,
+                }, 10)
+        ]
+    else:
+        calls = [
+            call('train', {
+                'learning_rate': 0.01,
+                'momentum': 0.95
+            }, 1),
+            call('train', {
+                'learning_rate': 0.01,
+                'momentum': 0.95
+            }, 6),
+            call('train', {
+                'learning_rate': 0.0009549150281252633,
+                'momentum': 0.95
+            }, 10)
+        ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@@ -443,13 +549,15 @@ def test_wandb_hook():

 def _build_demo_runner(runner_type='EpochBasedRunner',
                        max_epochs=1,
-                       max_iters=None):
+                       max_iters=None,
+                       multi_optimziers=False):

     class Model(nn.Module):

         def __init__(self):
             super().__init__()
             self.linear = nn.Linear(2, 1)
+            self.conv = nn.Conv2d(3, 3, 3)

         def forward(self, x):
             return self.linear(x)
@@ -462,7 +570,15 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
     model = Model()

-    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
+    if multi_optimziers:
+        optimizer = {
+            'model1':
+            torch.optim.SGD(model.linear.parameters(), lr=0.02, momentum=0.95),
+            'model2':
+            torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.9),
+        }
+    else:
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

     log_config = dict(
         interval=1, hooks=[
...
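
Because the two demo optimizers start from different bases (lr=0.02, momentum=0.95 for model1 versus lr=0.01, momentum=0.9 for model2), a schedule expressed as a ratio of the base value scales the two optimizers proportionally. That is why, in the cosine and cosine-restart expectations above, model2's learning rates sit at half of model1's, while the OneCycle test, which works from max_lr, shows identical values for both. A small illustrative check (the ratio is invented):

    # Illustrative arithmetic only: a ratio-based schedule preserves the base-LR ratio.
    base_lr = {'model1': 0.02, 'model2': 0.01}
    ratio_at_step = 0.5  # invented ratio for some step of a schedule

    scaled = {name: lr * ratio_at_step for name, lr in base_lr.items()}
    print(scaled)  # {'model1': 0.01, 'model2': 0.005} -- model2 is half of model1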