Unverified Commit 48089881 authored by Harry, committed by GitHub

Fix potential bug in DefaultOptimizerConstructor (#355)

* fix: fix custom_keys when weight_decay is not specified

* docs: small fix

* feat: add type check

* refactor: move if outside for
parent dcc20f3a
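To make the fix concrete, here is a minimal sketch of the behaviour this commit enforces. It assumes DefaultOptimizerConstructor can be imported from mmcv.runner (the import path may differ between mmcv versions) and uses a bare nn.Conv2d with a hypothetical 'weight' custom key instead of the test suite's ExampleModel:

import torch.nn as nn
from mmcv.runner import DefaultOptimizerConstructor  # assumed import path

model = nn.Conv2d(3, 8, 3)  # any module with named parameters will do

# lr_mult alone never needs a base weight_decay: parameters whose full name
# contains 'weight' simply get lr * lr_mult, and weight_decay is left unset.
constructor = DefaultOptimizerConstructor(
    dict(type='SGD', lr=0.01),
    dict(custom_keys={'weight': dict(lr_mult=0.1)}))
optimizer = constructor(model)

# decay_mult without a weight_decay in optimizer_cfg is now rejected with a
# ValueError instead of failing later on `None * decay_mult`.
try:
    constructor = DefaultOptimizerConstructor(
        dict(type='SGD', lr=0.01),
        dict(custom_keys={'weight': dict(decay_mult=0.5)}))
    constructor(model)
except ValueError as e:
    print(e)  # base_wd should not be None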
@@ -87,11 +87,15 @@ class DefaultOptimizerConstructor:
             raise TypeError('paramwise_cfg should be None or a dict, '
                             f'but got {type(self.paramwise_cfg)}')
-        if ('custom_keys' in self.paramwise_cfg
-                and not isinstance(self.paramwise_cfg['custom_keys'], dict)):
-            raise TypeError(
-                'If specified, custom_keys must be a dict, '
-                f'but got {type(self.paramwise_cfg["custom_keys"])}')
+        if 'custom_keys' in self.paramwise_cfg:
+            if not isinstance(self.paramwise_cfg['custom_keys'], dict):
+                raise TypeError(
+                    'If specified, custom_keys must be a dict, '
+                    f'but got {type(self.paramwise_cfg["custom_keys"])}')
+            if self.base_wd is None:
+                for key in self.paramwise_cfg['custom_keys']:
+                    if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
+                        raise ValueError('base_wd should not be None')

         # get base lr and weight decay
         # weight_decay must be explicitly specified if mult is specified

@@ -154,11 +158,11 @@ class DefaultOptimizerConstructor:
             for key in sorted_keys:
                 if key in f'{prefix}.{name}':
                     is_custom = True
-                    param_group['lr'] = self.base_lr * custom_keys[key].get(
-                        'lr_mult', 1.)
-                    param_group[
-                        'weight_decay'] = self.base_wd * custom_keys[key].get(
-                            'decay_mult', 1.)
+                    lr_mult = custom_keys[key].get('lr_mult', 1.)
+                    param_group['lr'] = self.base_lr * lr_mult
+                    if self.base_wd is not None:
+                        decay_mult = custom_keys[key].get('decay_mult', 1.)
+                        param_group['weight_decay'] = self.base_wd * decay_mult
                     break
             if not is_custom:
                 # bias_lr_mult affects all bias parameters except for norm.bias
...
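The second hunk also restructures how the multipliers are applied once a key matches. Below is a standalone sketch of that logic, using a hypothetical helper name (apply_custom_keys) rather than the real add_params method, and assuming the most specific (longest) key should win, which is the purpose of the precomputed sorted_keys in the surrounding code:

# Illustrative only: the real code mutates a param_group dict in place
# inside DefaultOptimizerConstructor.add_params.
def apply_custom_keys(full_name, custom_keys, base_lr, base_wd):
    group = {}
    for key in sorted(custom_keys, key=len, reverse=True):  # longest key wins
        if key in full_name:  # substring match against f'{prefix}.{name}'
            group['lr'] = base_lr * custom_keys[key].get('lr_mult', 1.)
            # the guard added by this commit: only scale weight_decay when a
            # base value was actually configured
            if base_wd is not None:
                group['weight_decay'] = base_wd * custom_keys[key].get(
                    'decay_mult', 1.)
            break
    return group

# lr_mult still applies when base_wd is None; weight_decay is simply skipped
print(apply_custom_keys('backbone.layer1.conv.weight',
                        {'backbone': dict(lr_mult=10)}, 1.0, None))
# -> {'lr': 10.0}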
@@ -355,7 +355,7 @@ def test_default_optimizer_constructor():
     assert len(optimizer.param_groups) == len(model_parameters) == 11
     check_optimizer(optimizer, model, **paramwise_cfg)

-    # test DefaultOptimizerConstructor with custom_groups and ExampleModel
+    # test DefaultOptimizerConstructor with custom_keys and ExampleModel
     model = ExampleModel()
     optimizer_cfg = dict(
         type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)

@@ -372,7 +372,16 @@ def test_default_optimizer_constructor():
         # custom_keys should be a dict
         paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
         optim_constructor = DefaultOptimizerConstructor(
-            optim_constructor, paramwise_cfg_)
+            optimizer_cfg, paramwise_cfg_)
+        optimizer = optim_constructor(model)
+
+    with pytest.raises(ValueError):
+        # if 'decay_mult' is specified in custom_keys, weight_decay should be
+        # specified
+        optimizer_cfg_ = dict(type='SGD', lr=0.01)
+        paramwise_cfg_ = dict(custom_keys={'.backbone': dict(decay_mult=0.5)})
+        optim_constructor = DefaultOptimizerConstructor(
+            optimizer_cfg_, paramwise_cfg_)
         optimizer = optim_constructor(model)

     optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
...

@@ -435,6 +444,52 @@ def test_default_optimizer_constructor():
                     assert param_groups[i][setting] == settings[
                         setting], f'{name} {setting}'

+    # test DefaultOptimizerConstructor with custom_keys and ExampleModel 2
+    model = ExampleModel()
+    optimizer_cfg = dict(type='SGD', lr=base_lr, momentum=momentum)
+    paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)})
+    optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
+                                                    paramwise_cfg)
+    optimizer = optim_constructor(model)
+    # check optimizer type and default config
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == base_lr
+    assert optimizer.defaults['momentum'] == momentum
+    assert optimizer.defaults['weight_decay'] == 0
+    # check params groups
+    param_groups = optimizer.param_groups
+
+    groups = []
+    group_settings = []
+    # group 1, matches of 'param1'
+    groups.append(['param1', 'sub.param1'])
+    group_settings.append({
+        'lr': base_lr * 10,
+        'momentum': momentum,
+        'weight_decay': 0,
+    })
+    # group 2, default group
+    groups.append([
+        'sub.conv1.weight', 'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias',
+        'conv1.weight', 'conv2.weight', 'conv2.bias', 'bn.weight', 'bn.bias'
+    ])
+    group_settings.append({
+        'lr': base_lr,
+        'momentum': momentum,
+        'weight_decay': 0
+    })
+
+    assert len(param_groups) == 11
+    for i, (name, param) in enumerate(model.named_parameters()):
+        assert torch.equal(param_groups[i]['params'][0], param)
+        for group, settings in zip(groups, group_settings):
+            if name in group:
+                for setting in settings:
+                    assert param_groups[i][setting] == settings[
+                        setting], f'{name} {setting}'
+

 def test_torch_optimizers():
     torch_optimizers = [
...