Unverified Commit d525cfde authored by zhouzaida, committed by GitHub

[Fix] Fix bug of lr updater hook (#907)

* [Fix] fix bug of lr update hook

* [Fix] fix bug of lr update hook

* [Fix] Fix bug of lr updater hook
parent 03a2e3a1
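
Before this change, get_warmup_lr always read self.regular_lr as a flat list, so warmup broke when a runner was built with a dict of optimizers, where regular_lr becomes a dict keyed by optimizer name. The hunks below factor the warmup math into a nested helper and dispatch on the container type. A minimal sketch of that dispatch pattern, using illustrative names rather than the mmcv API:

# Sketch of the list-vs-dict dispatch introduced by this commit; the names
# here are illustrative and not part of mmcv.
def scale_warmup(regular, warmup_ratio=0.1):

    def _scale(lrs):
        # 'constant' warmup: every param group uses warmup_ratio * regular lr
        return [lr * warmup_ratio for lr in lrs]

    if isinstance(regular, dict):
        # one entry per optimizer, e.g. {'model1': [...], 'model2': [...]}
        return {name: _scale(lrs) for name, lrs in regular.items()}
    return _scale(regular)

print(scale_warmup([0.02]))  # a single list of scaled lrs
print(scale_warmup({'model1': [0.02], 'model2': [0.01]}))  # a dict of scaled lists
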
@@ -82,16 +82,27 @@ class LrUpdaterHook(Hook):
        return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]

    def get_warmup_lr(self, cur_iters):

        def _get_warmup_lr(cur_iters, regular_lr):
            if self.warmup == 'constant':
                warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
            elif self.warmup == 'linear':
                k = (1 - cur_iters / self.warmup_iters) * (1 -
                                                           self.warmup_ratio)
                warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
            elif self.warmup == 'exp':
                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
                warmup_lr = [_lr * k for _lr in regular_lr]
            return warmup_lr

        if isinstance(self.regular_lr, dict):
            lr_groups = {}
            for key, regular_lr in self.regular_lr.items():
                lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
            return lr_groups
        else:
            return _get_warmup_lr(cur_iters, self.regular_lr)

    def before_run(self, runner):
        # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
        # it will be set according to the optimizer params
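
In the linear branch above, k = (1 - cur_iters / warmup_iters) * (1 - warmup_ratio), so the warmup lr starts at warmup_ratio * regular_lr and ramps up to regular_lr by warmup_iters. A standalone check of that formula; the numbers are chosen for illustration and come from no particular mmcv config:

def linear_warmup(regular_lr, cur_iters, warmup_iters=100, warmup_ratio=0.1):
    k = (1 - cur_iters / warmup_iters) * (1 - warmup_ratio)
    return [lr * (1 - k) for lr in regular_lr]

print(linear_warmup([0.02], cur_iters=0))    # roughly 0.002, i.e. warmup_ratio * lr
print(linear_warmup([0.02], cur_iters=50))   # roughly 0.011, halfway up the ramp
print(linear_warmup([0.02], cur_iters=100))  # 0.02, the regular lr is reached
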
@@ -31,6 +31,15 @@ class MomentumUpdaterHook(Hook):
        ]  # expected momentum if no warming up is performed

    def _set_momentum(self, runner, momentum_groups):
        if isinstance(runner.optimizer, dict):
            for k, optim in runner.optimizer.items():
                for param_group, mom in zip(optim.param_groups,
                                            momentum_groups[k]):
                    if 'momentum' in param_group.keys():
                        param_group['momentum'] = mom
                    elif 'betas' in param_group.keys():
                        param_group['betas'] = (mom, param_group['betas'][1])
        else:
            for param_group, mom in zip(runner.optimizer.param_groups,
                                        momentum_groups):
                if 'momentum' in param_group.keys():
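
The betas branch above exists because SGD exposes its momentum directly in each param group, while Adam-style optimizers keep the analogous value as the first element of betas. A short check against the real torch optimizers, updating momentum the same way _set_momentum does:

import torch
import torch.nn as nn

model = nn.Linear(2, 1)
sgd = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
adam = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999))

print('momentum' in sgd.param_groups[0])  # True
print('betas' in adam.param_groups[0])    # True

new_mom = 0.85
for group in adam.param_groups:
    # keep beta2, replace beta1 with the scheduled momentum
    group['betas'] = (new_mom, group['betas'][1])
print(adam.param_groups[0]['betas'])      # (0.85, 0.999)
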
@@ -42,31 +51,68 @@ class MomentumUpdaterHook(Hook):
                raise NotImplementedError

    def get_regular_momentum(self, runner):
        if isinstance(runner.optimizer, dict):
            momentum_groups = {}
            for k in runner.optimizer.keys():
                _momentum_group = [
                    self.get_momentum(runner, _base_momentum)
                    for _base_momentum in self.base_momentum[k]
                ]
                momentum_groups.update({k: _momentum_group})
            return momentum_groups
        else:
            return [
                self.get_momentum(runner, _base_momentum)
                for _base_momentum in self.base_momentum
            ]

    def get_warmup_momentum(self, cur_iters):

        def _get_warmup_momentum(cur_iters, regular_momentum):
            if self.warmup == 'constant':
                warmup_momentum = [
                    _momentum / self.warmup_ratio
                    for _momentum in self.regular_momentum
                ]
            elif self.warmup == 'linear':
                k = (1 - cur_iters / self.warmup_iters) * (1 -
                                                           self.warmup_ratio)
                warmup_momentum = [
                    _momentum / (1 - k) for _momentum in self.regular_mom
                ]
            elif self.warmup == 'exp':
                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
                warmup_momentum = [
                    _momentum / k for _momentum in self.regular_mom
                ]
            return warmup_momentum

        if isinstance(self.regular_momentum, dict):
            momentum_groups = {}
            for key, regular_momentum in self.regular_momentum.items():
                momentum_groups[key] = _get_warmup_momentum(
                    cur_iters, regular_momentum)
            return momentum_groups
        else:
            return _get_warmup_momentum(cur_iters, self.regular_momentum)

    def before_run(self, runner):
        # NOTE: when resuming from a checkpoint,
        # if 'initial_momentum' is not saved,
        # it will be set according to the optimizer params
        if isinstance(runner.optimizer, dict):
            self.base_momentum = {}
            for k, optim in runner.optimizer.items():
                for group in optim.param_groups:
                    if 'momentum' in group.keys():
                        group.setdefault('initial_momentum', group['momentum'])
                    else:
                        group.setdefault('initial_momentum', group['betas'][0])
                _base_momentum = [
                    group['initial_momentum'] for group in optim.param_groups
                ]
                self.base_momentum.update({k: _base_momentum})
        else:
            for group in runner.optimizer.param_groups:
                if 'momentum' in group.keys():
                    group.setdefault('initial_momentum', group['momentum'])
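
before_run seeds initial_momentum with dict.setdefault, so a fresh run records the optimizer's current momentum (or beta1), while a resumed checkpoint that already carries initial_momentum keeps its saved value. A small sketch of that behaviour, with plain dicts standing in for param groups:

group = {'momentum': 0.9}  # fresh run: no 'initial_momentum' key yet
group.setdefault('initial_momentum', group['momentum'])
print(group['initial_momentum'])  # 0.9

resumed = {'momentum': 0.5, 'initial_momentum': 0.9}  # restored from a checkpoint
resumed.setdefault('initial_momentum', resumed['momentum'])
print(resumed['initial_momentum'])  # still 0.9, setdefault does not overwrite
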
@@ -383,9 +429,11 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
        if isinstance(runner.optimizer, dict):
            momentum_groups = {}
            for k, optim in runner.optimizer.items():
                _momentum_group = [
                    self.get_momentum(runner, param_group)
                    for param_group in optim.param_groups
                ]
                momentum_groups.update({k: _momentum_group})
            return momentum_groups
        else:
            momentum_groups = []
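
Before this commit, the dict-of-optimizers branch in OneCycleMomentumUpdaterHook appended to momentum_groups[k] before that key existed, which raises KeyError on a plain dict; the updated code builds the per-optimizer list first and then inserts it. The failure mode in isolation:

momentum_groups = {}
try:
    momentum_groups['model1'].append(0.95)  # old pattern: key not created yet
except KeyError as err:
    print('KeyError:', err)

momentum_groups.update({'model1': [0.95]})  # pattern used by the fix
print(momentum_groups)  # {'model1': [0.95]}
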
@@ -151,11 +151,12 @@ def test_sync_buffers_hook():
    shutil.rmtree(runner.work_dir)


@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_momentum_runner_hook(multi_optimziers):
    """xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

    # add momentum scheduler
    hook_cfg = dict(
@@ -184,6 +185,31 @@ def test_momentum_runner_hook():
    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimziers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.01999999999999999,
                    'learning_rate/model2': 0.009999999999999995,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.2,
                    'learning_rate/model2': 0.1,
                    'momentum/model1': 0.85,
                    'momentum/model2': 0.8052631578947369,
                }, 5),
            call(
                'train', {
                    'learning_rate/model1': 0.155,
                    'learning_rate/model2': 0.0775,
                    'momentum/model1': 0.875,
                    'momentum/model2': 0.8289473684210527,
                }, 7)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.01999999999999999,
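
The learning_rate/model1 style keys in the expected calls come from logging one scalar per optimizer when the runner holds a dict of optimizers: the metric name is suffixed with the optimizer's key. A sketch of that composition; the real tag logic lives in mmcv's logger hooks:

lrs = {'model1': [0.02], 'model2': [0.01]}
tags = {f'learning_rate/{name}': groups[0] for name, groups in lrs.items()}
print(tags)  # {'learning_rate/model1': 0.02, 'learning_rate/model2': 0.01}
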
@@ -201,11 +227,12 @@ def test_momentum_runner_hook():
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_cosine_runner_hook(multi_optimziers):
    """xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

    # add momentum scheduler
@@ -235,6 +262,31 @@ def test_cosine_runner_hook():
    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimziers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.02,
                    'learning_rate/model2': 0.01,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.01,
                    'learning_rate/model2': 0.005,
                    'momentum/model1': 0.97,
                    'momentum/model2': 0.9189473684210527,
                }, 6),
            call(
                'train', {
                    'learning_rate/model1': 0.0004894348370484647,
                    'learning_rate/model2': 0.00024471741852423234,
                    'momentum/model1': 0.9890211303259032,
                    'momentum/model2': 0.9369673866245399,
                }, 10)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.02,
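
The learning-rate values expected above follow the usual cosine annealing curve, lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T)). The factors used below (0.5 at iteration 6, 0.9 at iteration 10) are inferred from the logged numbers, since the scheduler config itself is elided from this hunk:

import math

def annealing_cos(start, end, factor):
    # standard cosine interpolation from start to end as factor goes 0 -> 1
    return end + 0.5 * (start - end) * (1 + math.cos(math.pi * factor))

print(annealing_cos(0.02, 0.0, 0.0))  # 0.02, the initial lr
print(annealing_cos(0.02, 0.0, 0.5))  # 0.01, cf. the iteration-6 value above
print(annealing_cos(0.02, 0.0, 0.9))  # ~0.000489, cf. the iteration-10 value above
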
@@ -244,7 +296,8 @@ def test_cosine_runner_hook():
                'learning_rate': 0.01,
                'momentum': 0.97
            }, 6),
            call(
                'train', {
                    'learning_rate': 0.0004894348370484647,
                    'momentum': 0.9890211303259032
                }, 10)
@@ -252,7 +305,8 @@ def test_cosine_runner_hook():
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_one_cycle_runner_hook(multi_optimziers):
    """Test OneCycleLrUpdaterHook and OneCycleMomentumUpdaterHook."""
    with pytest.raises(AssertionError):
        # by_epoch should be False
@@ -268,7 +322,7 @@ def test_one_cycle_runner_hook():
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

    # add momentum scheduler
    hook_cfg = dict(
@@ -280,7 +334,7 @@ def test_one_cycle_runner_hook():
        three_phase=False)
    runner.register_hook_from_cfg(hook_cfg)

    # add LR scheduler
    hook_cfg = dict(
        type='OneCycleLrUpdaterHook',
        max_lr=0.01,
@@ -300,12 +354,38 @@ def test_one_cycle_runner_hook():
    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimziers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.0003999999999999993,
                    'learning_rate/model2': 0.0003999999999999993,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.95,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.00904508879153485,
                    'learning_rate/model2': 0.00904508879153485,
                    'momentum/model1': 0.8595491502812526,
                    'momentum/model2': 0.8595491502812526,
                }, 6),
            call(
                'train', {
                    'learning_rate/model1': 4e-08,
                    'learning_rate/model2': 4e-08,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.95,
                }, 10)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.0003999999999999993,
                'momentum': 0.95
            }, 1),
            call(
                'train', {
                    'learning_rate': 0.00904508879153485,
                    'momentum': 0.8595491502812526
                }, 6),
@@ -317,7 +397,8 @@ def test_one_cycle_runner_hook():
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_cosine_restart_lr_update_hook(multi_optimziers):
    """Test CosineRestartLrUpdaterHook."""
    with pytest.raises(AssertionError):
        # either `min_lr` or `min_lr_ratio` should be specified
@@ -359,7 +440,7 @@ def test_cosine_restart_lr_update_hook():
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

    # add cosine restart LR scheduler
    hook = CosineRestartLrUpdaterHook(
@@ -378,6 +459,31 @@ def test_cosine_restart_lr_update_hook():
    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimziers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.01,
                    'learning_rate/model2': 0.005,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.01,
                    'learning_rate/model2': 0.005,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 6),
            call(
                'train', {
                    'learning_rate/model1': 0.0009549150281252633,
                    'learning_rate/model2': 0.00047745751406263163,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 10)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.01,
@@ -443,13 +549,15 @@ def test_wandb_hook():

def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None,
                       multi_optimziers=False):

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)
            self.conv = nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.linear(x)
@@ -462,6 +570,14 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
    model = Model()

    if multi_optimziers:
        optimizer = {
            'model1':
            torch.optim.SGD(model.linear.parameters(), lr=0.02, momentum=0.95),
            'model2':
            torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.9),
        }
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

    log_config = dict(
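
_build_demo_runner now optionally returns two SGD optimizers keyed by name, one per sub-module, which is the shape runner.optimizer takes in the dict branches added above. A generic sketch of driving such a dict with plain PyTorch; this is not the mmcv runner's internal logic:

import torch
import torch.nn as nn

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.linear(x)

model = Model()
optimizer = {
    'model1': torch.optim.SGD(model.linear.parameters(), lr=0.02, momentum=0.95),
    'model2': torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.9),
}

loss = model(torch.ones((4, 2))).mean()
loss.backward()
for name, optim in optimizer.items():  # step every optimizer in the dict
    optim.step()
    optim.zero_grad()
    print(name, optim.param_groups[0]['lr'])
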