Unverified Commit ed2887bb authored by Wenwei Zhang, committed by GitHub

Support to specify LR of DCN's conv_offset (#344)

* Support to specify LR of DCN's conv_offset

* Resolve comments & add unit test

* Resolve formats

* Fix CI for DCN

* Mock DCN when cpu only

* Use mock for cpu testing

* Fix docstring and support ModulatedDCN

* set offset_lr_mult as dcn's arguments, link CU-49u01p

* fix lr bug

* fall back to set LR in constructor

* resolve comments
parent 467b4883
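For context, a minimal usage sketch of the new option, mirroring the unit tests in this diff; the SGD settings and the 0.1 multiplier are illustrative values, and `model` is assumed to be a module that contains DCN layers:

from mmcv.runner import DefaultOptimizerConstructor

# Base optimizer settings (illustrative values).
optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
# Scale the learning rate of every DCN conv_offset layer by 0.1.
paramwise_cfg = dict(dcn_offset_lr_mult=0.1)

optim_constructor = DefaultOptimizerConstructor(optimizer_cfg, paramwise_cfg)
# `model` is assumed to be an nn.Module containing e.g. DeformConv2dPack layers.
optimizer = optim_constructor(model)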
......@@ -186,6 +186,8 @@ class DeformConv2d(nn.Module):
bias=False):
super(DeformConv2d, self).__init__()
assert not bias, \
f'bias={bias} is not supported in DeformConv2d.'
assert in_channels % groups == 0, \
f'in_channels {in_channels} is not divisible by groups {groups}'
assert out_channels % groups == 0, \
......@@ -267,7 +269,6 @@ class DeformConv2dPack(DeformConv2d):
def __init__(self, *args, **kwargs):
super(DeformConv2dPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
......
......@@ -4,6 +4,7 @@ import torch
from torch.nn import GroupNorm, LayerNorm
from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
from mmcv.utils.ext_loader import check_ops_exist
from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
......@@ -27,19 +28,34 @@ class DefaultOptimizerConstructor:
and ``decay_mult``. See Example 2 below.
- ``bias_lr_mult`` (float): It will be multiplied to the learning
rate for all bias parameters (except for those in normalization
layers).
layers and offset layers of DCN).
- ``bias_decay_mult`` (float): It will be multiplied to the weight
decay for all bias parameters (except for those in
normalization layers and depthwise conv layers).
normalization layers, depthwise conv layers, and offset layers of DCN).
- ``norm_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of normalization
layers.
- ``dwconv_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of depthwise conv
layers.
- ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
rate for the parameters of the offset layer in the deformable convs
of a model.
- ``bypass_duplicate`` (bool): If true, the duplicate parameters
would not be added into the optimizer. Default: False.
Note:
1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
override the effect of ``bias_lr_mult`` on the bias of the offset
layer. So be careful when using both ``bias_lr_mult`` and
``dcn_offset_lr_mult``. If you wish to apply both of them to the
offset layer in deformable convs, set ``dcn_offset_lr_mult``
to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
apply it to all the DCN layers in the model. So be careful when
the model contains multiple DCN layers in places other than the
backbone.
Args:
model (:obj:`nn.Module`): The model with parameters to be optimized.
optimizer_cfg (dict): The config dict of the optimizer.
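To make the first note above concrete, here is a hedged sketch of how the two multipliers interact under the rules implemented in ``add_params`` below; the numbers are illustrative:

# Assume base_lr = 0.01 in optimizer_cfg.
paramwise_cfg = dict(bias_lr_mult=2, dcn_offset_lr_mult=0.1)
# dcn.conv_offset.weight -> lr = 0.01 * 0.1   (dcn_offset_lr_mult)
# dcn.conv_offset.bias   -> lr = 0.01 * 0.1   (bias_lr_mult is overridden here)
# other conv biases      -> lr = 0.01 * 2     (bias_lr_mult)
# To apply both multipliers to the offset layer, pass
# dcn_offset_lr_mult=0.1 * 2 explicitly, as the note suggests.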
......@@ -117,7 +133,7 @@ class DefaultOptimizerConstructor:
return not param.isdisjoint(param_set)
def add_params(self, params, module, prefix=''):
def add_params(self, params, module, prefix='', is_dcn_module=None):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
......@@ -128,6 +144,9 @@ class DefaultOptimizerConstructor:
in place.
module (nn.Module): The module to be added.
prefix (str): The prefix of the module
is_dcn_module (bool|None): Whether the current module is a
submodule of a DCN module; it is passed down to control the
conv_offset layer's learning rate. Defaults to None.
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
......@@ -139,6 +158,7 @@ class DefaultOptimizerConstructor:
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
# special rules for norm layers and depth-wise conv layers
is_norm = isinstance(module,
......@@ -167,10 +187,18 @@ class DefaultOptimizerConstructor:
decay_mult = custom_keys[key].get('decay_mult', 1.)
param_group['weight_decay'] = self.base_wd * decay_mult
break
if not is_custom:
# bias_lr_mult affects all bias parameters except for norm.bias
if name == 'bias' and not is_norm:
# bias_lr_mult affects all bias parameters
# except for norm.bias and dcn.conv_offset.bias
if name == 'bias' and not (is_norm or is_dcn_module):
param_group['lr'] = self.base_lr * bias_lr_mult
if (prefix.find('conv_offset') != -1 and is_dcn_module
and isinstance(module, torch.nn.Conv2d)):
# apply dcn_offset_lr_mult to both the weight and bias of conv_offset
param_group['lr'] = self.base_lr * dcn_offset_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
......@@ -182,14 +210,25 @@ class DefaultOptimizerConstructor:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# bias lr and decay
elif name == 'bias':
elif name == 'bias' and not is_dcn_module:
# TODO: current bias_decay_mult will have an effect on DCN
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
params.append(param_group)
if check_ops_exist():
from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
is_dcn_module = isinstance(module,
(DeformConv2d, ModulatedDeformConv2d))
else:
is_dcn_module = False
for child_name, child_mod in module.named_children():
child_prefix = f'{prefix}.{child_name}' if prefix else child_name
self.add_params(params, child_mod, prefix=child_prefix)
self.add_params(
params,
child_mod,
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def __call__(self, model):
if hasattr(model, 'module'):
......
import importlib
import os
import pkgutil
from collections import namedtuple
import torch
......@@ -25,3 +26,8 @@ else:
ext_list.append(
extension.load(fun, name, lib_dir=lib_root).op_)
return ExtModule(*ext_list)
def check_ops_exist():
ext_loader = pkgutil.find_loader('mmcv._ext')
return ext_loader is not None
import sys
import warnings
from unittest.mock import MagicMock
import pytest
import torch
......@@ -7,6 +9,12 @@ import torch.nn as nn
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner.optimizer import build_optimizer, build_optimizer_constructor
from mmcv.runner.optimizer.builder import TORCH_OPTIMIZERS
from mmcv.utils.ext_loader import check_ops_exist
OPS_AVAILABLE = check_ops_exist()
if not OPS_AVAILABLE:
sys.modules['mmcv.ops'] = MagicMock(
DeformConv2d=dict, ModulatedDeformConv2d=dict)
class SubModel(nn.Module):
......@@ -30,6 +38,10 @@ class ExampleModel(nn.Module):
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
self.bn = nn.BatchNorm2d(2)
self.sub = SubModel()
if OPS_AVAILABLE:
from mmcv.ops import DeformConv2dPack
self.dcn = DeformConv2dPack(
3, 4, kernel_size=3, deformable_groups=1)
def forward(self, x):
return x
......@@ -46,6 +58,10 @@ class ExampleDuplicateModel(nn.Module):
self.sub = SubModel()
self.conv3 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
self.conv3[0] = self.conv1[0]
if OPS_AVAILABLE:
from mmcv.ops import DeformConv2dPack
self.dcn = DeformConv2dPack(
3, 4, kernel_size=3, deformable_groups=1)
def forward(self, x):
return x
......@@ -72,11 +88,19 @@ def check_default_optimizer(optimizer, model, prefix=''):
assert optimizer.defaults['momentum'] == momentum
assert optimizer.defaults['weight_decay'] == base_wd
param_groups = optimizer.param_groups[0]
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias', 'bn.weight',
'bn.bias', 'sub.param1', 'sub.conv1.weight', 'sub.conv1.bias',
'sub.gn.weight', 'sub.gn.bias'
]
if OPS_AVAILABLE:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
]
else:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
]
param_dict = dict(model.named_parameters())
assert len(param_groups['params']) == len(param_names)
for i in range(len(param_groups['params'])):
......@@ -84,14 +108,15 @@ def check_default_optimizer(optimizer, model, prefix=''):
param_dict[prefix + param_names[i]])
def check_optimizer(optimizer,
model,
prefix='',
bias_lr_mult=1,
bias_decay_mult=1,
norm_decay_mult=1,
dwconv_decay_mult=1,
bypass_duplicate=False):
def check_sgd_optimizer(optimizer,
model,
prefix='',
bias_lr_mult=1,
bias_decay_mult=1,
norm_decay_mult=1,
dwconv_decay_mult=1,
dcn_offset_lr_mult=1,
bypass_duplicate=False):
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == base_lr
......@@ -103,6 +128,7 @@ def check_optimizer(optimizer,
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
assert param_group['momentum'] == momentum
# param1
param1 = param_groups[0]
assert param1['lr'] == base_lr
......@@ -148,6 +174,19 @@ def check_optimizer(optimizer,
assert sub_gn_bias['lr'] == base_lr
assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
if OPS_AVAILABLE:
dcn_conv_weight = param_groups[11]
assert dcn_conv_weight['lr'] == base_lr
assert dcn_conv_weight['weight_decay'] == base_wd
dcn_offset_weight = param_groups[12]
assert dcn_offset_weight['lr'] == base_lr * dcn_offset_lr_mult
assert dcn_offset_weight['weight_decay'] == base_wd
dcn_offset_bias = param_groups[13]
assert dcn_offset_bias['lr'] == base_lr * dcn_offset_lr_mult
assert dcn_offset_bias['weight_decay'] == base_wd
def test_default_optimizer_constructor():
model = ExampleModel()
......@@ -229,11 +268,12 @@ def test_default_optimizer_constructor():
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
check_optimizer(optimizer, model, **paramwise_cfg)
check_sgd_optimizer(optimizer, model, **paramwise_cfg)
# paramwise_cfg with ExampleModel, weight decay is None
model = ExampleModel()
......@@ -274,6 +314,14 @@ def test_default_optimizer_constructor():
# sub.gn.bias
assert param_groups[10]['lr'] == base_lr
if OPS_AVAILABLE:
# dcn.weight
assert param_groups[11]['lr'] == base_lr
# dcn.conv_offset.weight
assert param_groups[12]['lr'] == base_lr
# dcn.conv_offset.bias
assert param_groups[13]['lr'] == base_lr
# paramwise_cfg with pseudo data parallel
model = PseudoDataParallel()
optimizer_cfg = dict(
......@@ -282,11 +330,12 @@ def test_default_optimizer_constructor():
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
check_optimizer(optimizer, model, prefix='module.', **paramwise_cfg)
check_sgd_optimizer(optimizer, model, prefix='module.', **paramwise_cfg)
# paramwise_cfg with DataParallel
if torch.cuda.is_available():
......@@ -297,11 +346,13 @@ def test_default_optimizer_constructor():
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(model)
check_optimizer(optimizer, model, prefix='module.', **paramwise_cfg)
check_sgd_optimizer(
optimizer, model, prefix='module.', **paramwise_cfg)
# paramwise_cfg with ExampleModel and no grad
for param in model.parameters():
......@@ -342,6 +393,7 @@ def test_default_optimizer_constructor():
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
bypass_duplicate=True)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
......@@ -352,8 +404,9 @@ def test_default_optimizer_constructor():
assert str(w[0].message) == 'conv3.0 is duplicate. It is skipped ' \
'since bypass_duplicate=True'
model_parameters = list(model.parameters())
assert len(optimizer.param_groups) == len(model_parameters) == 11
check_optimizer(optimizer, model, **paramwise_cfg)
num_params = 14 if OPS_AVAILABLE else 11
assert len(optimizer.param_groups) == len(model_parameters) == num_params
check_sgd_optimizer(optimizer, model, **paramwise_cfg)
# test DefaultOptimizerConstructor with custom_keys and ExampleModel
model = ExampleModel()
......@@ -435,7 +488,8 @@ def test_default_optimizer_constructor():
'weight_decay': base_wd
})
assert len(param_groups) == 11
num_params = 14 if OPS_AVAILABLE else 11
assert len(param_groups) == num_params
for i, (name, param) in enumerate(model.named_parameters()):
assert torch.equal(param_groups[i]['params'][0], param)
for group, settings in zip(groups, group_settings):
......@@ -481,7 +535,8 @@ def test_default_optimizer_constructor():
'weight_decay': 0
})
assert len(param_groups) == 11
num_params = 14 if OPS_AVAILABLE else 11
assert len(param_groups) == num_params
for i, (name, param) in enumerate(model.named_parameters()):
assert torch.equal(param_groups[i]['params'][0], param)
for group, settings in zip(groups, group_settings):
......@@ -507,14 +562,15 @@ def test_build_optimizer_constructor():
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor_cfg = dict(
type='DefaultOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
optimizer = optim_constructor(model)
check_optimizer(optimizer, model, **paramwise_cfg)
check_sgd_optimizer(optimizer, model, **paramwise_cfg)
from mmcv.runner import OPTIMIZERS
from mmcv.utils import build_from_cfg
......@@ -577,6 +633,7 @@ def test_build_optimizer():
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1))
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1))
optimizer = build_optimizer(model, optimizer_cfg)
check_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])
check_sgd_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])