Unverified commit 6f21d8b5, authored by Harry and committed by GitHub

Add optimizer constructor from mmdetection (#313)

* feat: add optimizer constructor

* refactor: version
parent fe4f657f
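
This commit ports mmdetection's optimizer constructor into mmcv.runner so that optimizers can be built from a config dict, optionally with parameter-wise overrides. Below is a minimal usage sketch of the new API, assuming an illustrative toy model and hyper-parameter values; it is not part of the diff that follows.

# Minimal usage sketch (assumption: illustrative model and values only).
import torch.nn as nn

from mmcv.runner import build_optimizer

model = nn.Conv2d(3, 8, kernel_size=3)
cfg = dict(
    type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001,
    paramwise_cfg=dict(bias_lr_mult=2., norm_decay_mult=0.))
optimizer = build_optimizer(model, cfg)  # a torch.optim.SGD instance
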
@@ -7,6 +7,9 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
OptimizerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
from .priority import Priority, get_priority
from .runner import Runner
from .utils import get_host_info, get_time_str, obj_from_dict
@@ -18,5 +21,7 @@ __all__ = [
'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
-    'init_dist', 'get_dist_info', 'master_only'
+    'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
+    'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
+    'build_optimizer_constructor'
]
from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer,
build_optimizer_constructor)
from .default_constructor import DefaultOptimizerConstructor
__all__ = [
'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor'
]
import copy
import inspect
import torch
from ...utils import Registry, build_from_cfg
OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS = Registry('optimizer builder')
def register_torch_optimizers():
torch_optimizers = []
for module_name in dir(torch.optim):
if module_name.startswith('__'):
continue
_optim = getattr(torch.optim, module_name)
if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer):
OPTIMIZERS.register_module()(_optim)
torch_optimizers.append(module_name)
return torch_optimizers
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer_constructor(cfg):
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
def build_optimizer(model, cfg):
optimizer_cfg = copy.deepcopy(cfg)
constructor_type = optimizer_cfg.pop('constructor',
'DefaultOptimizerConstructor')
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
optim_constructor = build_optimizer_constructor(
dict(
type=constructor_type,
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg))
optimizer = optim_constructor(model)
return optimizer
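
The builder module above (imported as mmcv.runner.optimizer.builder in the tests further down) scans dir(torch.optim), registers every torch.optim.Optimizer subclass in the OPTIMIZERS registry, and exposes build_optimizer and build_optimizer_constructor. The same registry also accepts project-specific optimizers; here is a hedged sketch, where MyTinySGD is a hypothetical wrapper and not something added by this commit.

# Hedged sketch: registering a custom optimizer so build_optimizer can
# resolve it by its `type` string. `MyTinySGD` is a hypothetical name.
import torch
import torch.nn as nn

from mmcv.runner.optimizer import OPTIMIZERS, build_optimizer


@OPTIMIZERS.register_module()
class MyTinySGD(torch.optim.SGD):
    """Plain SGD, registered under its own class name."""


model = nn.Linear(2, 2)
optimizer = build_optimizer(model, dict(type='MyTinySGD', lr=0.01))
assert isinstance(optimizer, MyTinySGD)

The next file, default_constructor.py (referenced by the relative import in the __init__ above), implements the DefaultOptimizerConstructor that this builder uses by default.
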
import warnings
import torch
from torch.nn import GroupNorm, LayerNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from mmcv.utils import build_from_cfg, is_list_of
from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
@OPTIMIZER_BUILDERS.register_module()
class DefaultOptimizerConstructor(object):
"""Default constructor for optimizers.
By default each parameter share the same optimizer settings, and we
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
It is a dict and may contain the following fields:
- ``bias_lr_mult`` (float): It will be multiplied to the learning
rate for all bias parameters (except for those in normalization
layers).
- ``bias_decay_mult`` (float): It will be multiplied to the weight
decay for all bias parameters (except for those in
normalization layers and depthwise conv layers).
- ``norm_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of normalization
layers.
- ``dwconv_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of depthwise conv
layers.
- ``bypass_duplicate`` (bool): If true, the duplicate parameters
would not be added into optimizer. Default: False
Args:
model (:obj:`nn.Module`): The model with parameters to be optimized.
optimizer_cfg (dict): The config dict of the optimizer.
Positional fields are
- `type`: class name of the optimizer.
Optional fields are
- any arguments of the corresponding optimizer type, e.g.,
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options.
Example:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
>>> weight_decay=0.0001)
>>> paramwise_cfg = dict(norm_decay_mult=0.)
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
"""
def __init__(self, optimizer_cfg, paramwise_cfg=None):
if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict, '
                            f'but got {type(optimizer_cfg)}')
self.optimizer_cfg = optimizer_cfg
self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
self.base_lr = optimizer_cfg.get('lr', None)
self.base_wd = optimizer_cfg.get('weight_decay', None)
self._validate_cfg()
def _validate_cfg(self):
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
# get base lr and weight decay
# weight_decay must be explicitly specified if mult is specified
if ('bias_decay_mult' in self.paramwise_cfg
or 'norm_decay_mult' in self.paramwise_cfg
or 'dwconv_decay_mult' in self.paramwise_cfg):
if self.base_wd is None:
raise ValueError('base_wd should not be None')
def _is_in(self, param_group, param_group_list):
assert is_list_of(param_group_list, dict)
param = set(param_group['params'])
param_set = set()
for group in param_group_list:
param_set.update(set(group['params']))
return not param.isdisjoint(param_set)
def add_params(self, params, module, prefix=''):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
prefix (str): The prefix of the module
"""
# get param-wise options
bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
# special rules for norm layers and depth-wise conv layers
is_norm = isinstance(module,
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
is_dwconv = (
isinstance(module, torch.nn.Conv2d)
and module.in_channels == module.groups)
for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
if not param.requires_grad:
params.append(param_group)
continue
if bypass_duplicate and self._is_in(param_group, params):
warnings.warn(f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}')
continue
# bias_lr_mult affects all bias parameters except for norm.bias
if name == 'bias' and not is_norm:
param_group['lr'] = self.base_lr * bias_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm:
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# depth-wise conv
elif is_dwconv:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# bias lr and decay
elif name == 'bias':
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
params.append(param_group)
for child_name, child_mod in module.named_children():
child_prefix = f'{prefix}.{child_name}' if prefix else child_name
self.add_params(params, child_mod, prefix=child_prefix)
def __call__(self, model):
if hasattr(model, 'module'):
model = model.module
optimizer_cfg = self.optimizer_cfg.copy()
# if no paramwise option is specified, just use the global setting
if not self.paramwise_cfg:
optimizer_cfg['params'] = model.parameters()
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
# set param-wise lr and weight decay recursively
params = []
self.add_params(params, model)
optimizer_cfg['params'] = params
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
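
DefaultOptimizerConstructor above turns every parameter into its own param group while recursing through child modules, and applies the paramwise_cfg multipliers to each group. A hedged inspection sketch (the toy Conv + BN model and the hyper-parameter values are illustrative assumptions, not part of the commit) makes the effect visible; the version bump follows.

# Hedged sketch: inspecting the param groups produced for a toy Conv + BN
# model; model and values are illustrative assumptions.
import torch.nn as nn

from mmcv.runner import DefaultOptimizerConstructor

model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1), nn.BatchNorm2d(4))
constructor = DefaultOptimizerConstructor(
    optimizer_cfg=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=1e-4),
    paramwise_cfg=dict(bias_lr_mult=2., norm_decay_mult=0.))
optimizer = constructor(model)
for (name, _), group in zip(model.named_parameters(), optimizer.param_groups):
    # Expected: the conv bias gets lr=0.2, the BN weight/bias get
    # weight_decay=0., everything else keeps the base settings.
    print(name, group['lr'], group['weight_decay'])
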
# Copyright (c) Open-MMLab. All rights reserved.
-__version__ = '0.5.8'
+__version__ = '0.5.9'
import warnings
import pytest
import torch
import torch.nn as nn
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner.optimizer import build_optimizer, build_optimizer_constructor
from mmcv.runner.optimizer.builder import TORCH_OPTIMIZERS
class SubModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
self.gn = nn.GroupNorm(2, 2)
self.param1 = nn.Parameter(torch.ones(1))
def forward(self, x):
return x
class ExampleModel(nn.Module):
def __init__(self):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
self.bn = nn.BatchNorm2d(2)
self.sub = SubModel()
def forward(self, x):
return x
class ExampleDuplicateModel(nn.Module):
def __init__(self):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
self.conv2 = nn.Sequential(nn.Conv2d(4, 2, kernel_size=1))
self.bn = nn.BatchNorm2d(2)
self.sub = SubModel()
self.conv3 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
self.conv3[0] = self.conv1[0]
def forward(self, x):
return x
class PseudoDataParallel(nn.Module):
def __init__(self):
super().__init__()
self.module = ExampleModel()
def forward(self, x):
return x
base_lr = 0.01
base_wd = 0.0001
momentum = 0.9
def check_default_optimizer(optimizer, model, prefix=''):
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['momentum'] == momentum
assert optimizer.defaults['weight_decay'] == base_wd
param_groups = optimizer.param_groups[0]
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias', 'bn.weight',
'bn.bias', 'sub.param1', 'sub.conv1.weight', 'sub.conv1.bias',
'sub.gn.weight', 'sub.gn.bias'
]
param_dict = dict(model.named_parameters())
assert len(param_groups['params']) == len(param_names)
for i in range(len(param_groups['params'])):
assert torch.equal(param_groups['params'][i],
param_dict[prefix + param_names[i]])
def check_optimizer(optimizer,
model,
prefix='',
bias_lr_mult=1,
bias_decay_mult=1,
norm_decay_mult=1,
dwconv_decay_mult=1,
bypass_duplicate=False):
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['momentum'] == momentum
assert optimizer.defaults['weight_decay'] == base_wd
model_parameters = list(model.parameters())
assert len(param_groups) == len(model_parameters)
for i, param in enumerate(model_parameters):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
assert param_group['momentum'] == momentum
# param1
param1 = param_groups[0]
assert param1['lr'] == base_lr
assert param1['weight_decay'] == base_wd
# conv1.weight
conv1_weight = param_groups[1]
assert conv1_weight['lr'] == base_lr
assert conv1_weight['weight_decay'] == base_wd
# conv2.weight
conv2_weight = param_groups[2]
assert conv2_weight['lr'] == base_lr
assert conv2_weight['weight_decay'] == base_wd
# conv2.bias
conv2_bias = param_groups[3]
assert conv2_bias['lr'] == base_lr * bias_lr_mult
assert conv2_bias['weight_decay'] == base_wd * bias_decay_mult
# bn.weight
bn_weight = param_groups[4]
assert bn_weight['lr'] == base_lr
assert bn_weight['weight_decay'] == base_wd * norm_decay_mult
# bn.bias
bn_bias = param_groups[5]
assert bn_bias['lr'] == base_lr
assert bn_bias['weight_decay'] == base_wd * norm_decay_mult
# sub.param1
sub_param1 = param_groups[6]
assert sub_param1['lr'] == base_lr
assert sub_param1['weight_decay'] == base_wd
# sub.conv1.weight
sub_conv1_weight = param_groups[7]
assert sub_conv1_weight['lr'] == base_lr
assert sub_conv1_weight['weight_decay'] == base_wd * dwconv_decay_mult
# sub.conv1.bias
sub_conv1_bias = param_groups[8]
assert sub_conv1_bias['lr'] == base_lr * bias_lr_mult
assert sub_conv1_bias['weight_decay'] == base_wd * dwconv_decay_mult
# sub.gn.weight
sub_gn_weight = param_groups[9]
assert sub_gn_weight['lr'] == base_lr
assert sub_gn_weight['weight_decay'] == base_wd * norm_decay_mult
# sub.gn.bias
sub_gn_bias = param_groups[10]
assert sub_gn_bias['lr'] == base_lr
assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
def test_default_optimizer_constructor():
model = ExampleModel()
with pytest.raises(TypeError):
# optimizer_cfg must be a dict
optimizer_cfg = []
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optim_constructor(model)
with pytest.raises(TypeError):
# paramwise_cfg must be a dict or None
optimizer_cfg = dict(lr=0.0001)
paramwise_cfg = ['error']
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor(model)
with pytest.raises(ValueError):
# bias_decay_mult/norm_decay_mult is specified but weight_decay is None
optimizer_cfg = dict(lr=0.0001, weight_decay=None)
paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor(model)
# basic config with ExampleModel
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
check_default_optimizer(optimizer, model)
# basic config with pseudo data parallel
model = PseudoDataParallel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = None
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
check_default_optimizer(optimizer, model, prefix='module.')
# basic config with DataParallel
if torch.cuda.is_available():
model = torch.nn.DataParallel(ExampleModel())
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = None
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
check_default_optimizer(optimizer, model, prefix='module.')
# Empty paramwise_cfg with ExampleModel
model = ExampleModel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = dict()
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
check_default_optimizer(optimizer, model)
# Empty paramwise_cfg with ExampleModel and no grad
model = ExampleModel()
for param in model.parameters():
param.requires_grad = False
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = dict()
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
check_default_optimizer(optimizer, model)
# paramwise_cfg with ExampleModel
model = ExampleModel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
check_optimizer(optimizer, model, **paramwise_cfg)
# paramwise_cfg with ExampleModel, weight decay is None
model = ExampleModel()
optimizer_cfg = dict(type='Rprop', lr=base_lr)
paramwise_cfg = dict(bias_lr_mult=2)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.Rprop)
assert optimizer.defaults['lr'] == base_lr
model_parameters = list(model.parameters())
assert len(param_groups) == len(model_parameters)
for i, param in enumerate(model_parameters):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
# param1
assert param_groups[0]['lr'] == base_lr
# conv1.weight
assert param_groups[1]['lr'] == base_lr
# conv2.weight
assert param_groups[2]['lr'] == base_lr
# conv2.bias
assert param_groups[3]['lr'] == base_lr * paramwise_cfg['bias_lr_mult']
# bn.weight
assert param_groups[4]['lr'] == base_lr
# bn.bias
assert param_groups[5]['lr'] == base_lr
# sub.param1
assert param_groups[6]['lr'] == base_lr
# sub.conv1.weight
assert param_groups[7]['lr'] == base_lr
# sub.conv1.bias
assert param_groups[8]['lr'] == base_lr * paramwise_cfg['bias_lr_mult']
# sub.gn.weight
assert param_groups[9]['lr'] == base_lr
# sub.gn.bias
assert param_groups[10]['lr'] == base_lr
# paramwise_cfg with pseudo data parallel
model = PseudoDataParallel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
check_optimizer(optimizer, model, prefix='module.', **paramwise_cfg)
# paramwise_cfg with DataParallel
if torch.cuda.is_available():
model = torch.nn.DataParallel(ExampleModel())
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(model)
check_optimizer(optimizer, model, prefix='module.', **paramwise_cfg)
# paramwise_cfg with ExampleModel and no grad
for param in model.parameters():
param.requires_grad = False
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['momentum'] == momentum
assert optimizer.defaults['weight_decay'] == base_wd
for i, (name, param) in enumerate(model.named_parameters()):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
assert param_group['momentum'] == momentum
assert param_group['lr'] == base_lr
assert param_group['weight_decay'] == base_wd
# paramwise_cfg with bypass_duplicate option
model = ExampleDuplicateModel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
    with pytest.raises(ValueError) as excinfo:
        optim_constructor = DefaultOptimizerConstructor(
            optimizer_cfg, paramwise_cfg)
        optim_constructor(model)
    assert 'some parameters appear in more than one parameter ' \
        'group' == str(excinfo.value)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
bypass_duplicate=True)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        optimizer = optim_constructor(model)
    assert len(w) == 1
    assert str(w[0].message) == 'conv3.0 is duplicate. It is skipped ' \
        'since bypass_duplicate=True'
model_parameters = list(model.parameters())
assert len(optimizer.param_groups) == len(model_parameters) == 11
check_optimizer(optimizer, model, **paramwise_cfg)
def test_torch_optimizers():
torch_optimizers = [
'ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'LBFGS',
'Optimizer', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam'
]
assert set(torch_optimizers).issubset(set(TORCH_OPTIMIZERS))
def test_build_optimizer_constructor():
model = ExampleModel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
optim_constructor_cfg = dict(
type='DefaultOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
optimizer = optim_constructor(model)
check_optimizer(optimizer, model, **paramwise_cfg)
from mmcv.runner import OPTIMIZERS
from mmcv.utils import build_from_cfg
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
def __call__(self, model):
if hasattr(model, 'module'):
model = model.module
conv1_lr_mult = self.paramwise_cfg.get('conv1_lr_mult', 1.)
params = []
for name, param in model.named_parameters():
param_group = {'params': [param]}
if name.startswith('conv1') and param.requires_grad:
param_group['lr'] = self.base_lr * conv1_lr_mult
params.append(param_group)
optimizer_cfg['params'] = params
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
paramwise_cfg = dict(conv1_lr_mult=5)
optim_constructor_cfg = dict(
type='MyOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
optimizer = optim_constructor(model)
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['momentum'] == momentum
assert optimizer.defaults['weight_decay'] == base_wd
for i, param in enumerate(model.parameters()):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
assert param_group['momentum'] == momentum
# conv1.weight
assert param_groups[1]['lr'] == base_lr * paramwise_cfg['conv1_lr_mult']
assert param_groups[1]['weight_decay'] == base_wd
def test_build_optimizer():
model = ExampleModel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
optimizer = build_optimizer(model, optimizer_cfg)
check_default_optimizer(optimizer, model)
model = ExampleModel()
optimizer_cfg = dict(
type='SGD',
lr=base_lr,
weight_decay=base_wd,
momentum=momentum,
paramwise_cfg=dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1))
optimizer = build_optimizer(model, optimizer_cfg)
check_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])
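
A custom constructor can also be selected purely through the config, via the `constructor` key that build_optimizer pops before delegating (see builder.py above). Below is a hedged end-to-end sketch; Conv1LrConstructor and TinyModel are hypothetical names mirroring MyOptimizerConstructor from the test above.

# Hedged sketch: picking a custom constructor through the `constructor` key
# of the optimizer config. `Conv1LrConstructor` and `TinyModel` are
# hypothetical names used only for illustration.
import torch.nn as nn

from mmcv.runner import (OPTIMIZER_BUILDERS, OPTIMIZERS,
                         DefaultOptimizerConstructor, build_optimizer)
from mmcv.utils import build_from_cfg


@OPTIMIZER_BUILDERS.register_module()
class Conv1LrConstructor(DefaultOptimizerConstructor):

    def __call__(self, model):
        conv1_lr_mult = self.paramwise_cfg.get('conv1_lr_mult', 1.)
        params = []
        for name, param in model.named_parameters():
            group = {'params': [param]}
            if name.startswith('conv1'):
                group['lr'] = self.base_lr * conv1_lr_mult
            params.append(group)
        optimizer_cfg = self.optimizer_cfg.copy()
        optimizer_cfg['params'] = params
        return build_from_cfg(optimizer_cfg, OPTIMIZERS)


class TinyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=1)
        self.conv2 = nn.Conv2d(4, 2, kernel_size=1)

    def forward(self, x):
        return self.conv2(self.conv1(x))


optimizer = build_optimizer(
    TinyModel(),
    dict(type='SGD', lr=0.01, momentum=0.9,
         constructor='Conv1LrConstructor',
         paramwise_cfg=dict(conv1_lr_mult=5)))
# conv1.* param groups get lr=0.05; everything else keeps lr=0.01.
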