Commit 8bd1a5f7 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by Kai Chen
Browse files

Encapsulate DCN into a ConvModule & Conv_layers (#1894)

* update code

* reformat

* fix load pretrain bug

* reformat

* fix key warnings

* use build_conv_layer & ConvModule

* rm unrelated file

* fix resnext group bug

* change cfg to pass test

* use build_conv_layer to claim dcn in resnets

* change _version = 2

* MDCN to DCNv2

* change dependency of mmcv from 0.2.15 to 0.2.16

* resolve comments

* rm comments

* get_root_logger()
parent 1ebb6cb8
......@@ -11,7 +11,7 @@ model = dict(
frozen_stages=1,
style='pytorch',
dcn=dict(
modulated=False, deformable_groups=1, fallback_on_stride=False),
type='DCN', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
......
......@@ -11,7 +11,7 @@ model = dict(
frozen_stages=1,
style='pytorch',
dcn=dict(
modulated=False, deformable_groups=1, fallback_on_stride=False),
type='DCN', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
......
......@@ -10,7 +10,7 @@ model = dict(
frozen_stages=1,
style='pytorch',
dcn=dict(
modulated=False, deformable_groups=1, fallback_on_stride=False),
type='DCN', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
......
......@@ -12,8 +12,7 @@ model = dict(
frozen_stages=1,
style='pytorch',
dcn=dict(
modulated=False,
groups=32,
type='DCN',
deformable_groups=1,
fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
......
......@@ -10,7 +10,7 @@ model = dict(
frozen_stages=1,
style='pytorch',
dcn=dict(
modulated=True, deformable_groups=4, fallback_on_stride=False),
type='DCNv2', deformable_groups=4, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
......
......@@ -10,7 +10,7 @@ model = dict(
frozen_stages=1,
style='pytorch',
dcn=dict(
modulated=True, deformable_groups=1, fallback_on_stride=False),
type='DCNv2', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
......
......@@ -10,7 +10,7 @@ model = dict(
frozen_stages=1,
style='pytorch',
dcn=dict(
modulated=False, deformable_groups=1, fallback_on_stride=False),
type='DCN', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
......
# model settings
model = dict(
type='MaskRCNN',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch',
dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_head=dict(
type='FCNMaskHead',
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=81,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/mask_rcnn_dconv_c3-c5_r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
......@@ -13,7 +13,7 @@ model = dict(
spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2),
stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]],
dcn=dict(
modulated=False, deformable_groups=1, fallback_on_stride=False),
type='DCN', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True),
),
neck=dict(
......
......@@ -13,7 +13,7 @@ model = dict(
spatial_range=-1, num_heads=8, attention_type='1111', kv_stride=2),
stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]],
dcn=dict(
modulated=False, deformable_groups=1, fallback_on_stride=False),
type='DCN', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True),
),
neck=dict(
......
......@@ -15,7 +15,7 @@ model = dict(
frozen_stages=1,
style='pytorch',
dcn=dict(
modulated=False,
type='DCN',
groups=64,
deformable_groups=1,
fallback_on_stride=False),
......
......@@ -5,7 +5,7 @@ from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.plugins import GeneralizedAttention
from mmdet.ops import ContextBlock, DeformConv, ModulatedDeformConv
from mmdet.ops import ContextBlock
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
......@@ -143,10 +143,8 @@ class Bottleneck(nn.Module):
bias=False)
self.add_module(self.norm1_name, norm1)
fallback_on_stride = False
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = dcn.get('fallback_on_stride', False)
self.with_modulated_dcn = dcn.get('modulated', False)
fallback_on_stride = dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
conv_cfg,
......@@ -158,30 +156,17 @@ class Bottleneck(nn.Module):
dilation=dilation,
bias=False)
else:
assert conv_cfg is None, 'conv_cfg must be None for DCN'
self.deformable_groups = dcn.get('deformable_groups', 1)
if not self.with_modulated_dcn:
conv_op = DeformConv
offset_channels = 18
else:
conv_op = ModulatedDeformConv
offset_channels = 27
self.conv2_offset = nn.Conv2d(
planes,
self.deformable_groups * offset_channels,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation)
self.conv2 = conv_op(
assert self.conv_cfg is None, 'conv_cfg cannot be None for DCN'
self.conv2 = build_conv_layer(
dcn,
planes,
planes,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
deformable_groups=self.deformable_groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
conv_cfg,
......@@ -224,17 +209,7 @@ class Bottleneck(nn.Module):
out = self.norm1(out)
out = self.relu(out)
if not self.with_dcn:
out = self.conv2(out)
elif self.with_modulated_dcn:
offset_mask = self.conv2_offset(out)
offset = offset_mask[:, :18 * self.deformable_groups, :, :]
mask = offset_mask[:, -9 * self.deformable_groups:, :, :]
mask = mask.sigmoid()
out = self.conv2(out, offset, mask)
else:
offset = self.conv2_offset(out)
out = self.conv2(out, offset)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
......
......@@ -2,7 +2,6 @@ import math
import torch.nn as nn
from mmdet.ops import DeformConv, ModulatedDeformConv
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
from .resnet import Bottleneck as _Bottleneck
......@@ -41,8 +40,7 @@ class Bottleneck(_Bottleneck):
fallback_on_stride = False
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = self.dcn.get('fallback_on_stride', False)
self.with_modulated_dcn = self.dcn.get('modulated', False)
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
self.conv_cfg,
......@@ -56,22 +54,8 @@ class Bottleneck(_Bottleneck):
bias=False)
else:
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
groups = self.dcn.get('groups', 1)
deformable_groups = self.dcn.get('deformable_groups', 1)
if not self.with_modulated_dcn:
conv_op = DeformConv
offset_channels = 18
else:
conv_op = ModulatedDeformConv
offset_channels = 27
self.conv2_offset = nn.Conv2d(
width,
deformable_groups * offset_channels,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation)
self.conv2 = conv_op(
self.conv2 = build_conv_layer(
self.dcn,
width,
width,
kernel_size=3,
......@@ -79,8 +63,8 @@ class Bottleneck(_Bottleneck):
padding=self.dilation,
dilation=self.dilation,
groups=groups,
deformable_groups=deformable_groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
......
......@@ -3,12 +3,15 @@ import warnings
import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init
from mmdet.ops import DeformConvPack, ModulatedDeformConvPack
from .conv_ws import ConvWS2d
from .norm import build_norm_layer
conv_cfg = {
'Conv': nn.Conv2d,
'ConvWS': ConvWS2d,
'DCN': DeformConvPack,
'DCNv2': ModulatedDeformConvPack,
# TODO: octave conv
}
......
......@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torch.nn.modules.utils import _pair, _single
from . import deform_conv_cuda
......@@ -24,7 +24,7 @@ class DeformConvFunction(Function):
im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(
'Expected 4D tensor as input, got {}D tensor instead.'.format(
input.dim()))
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
......@@ -105,7 +105,7 @@ class DeformConvFunction(Function):
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
"convolution input is too small (output would be {})".format(
'convolution input is too small (output would be {})'.format(
'x'.join(map(str, output_size))))
return output_size
......@@ -217,6 +217,9 @@ class DeformConv(nn.Module):
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups,
......@@ -237,6 +240,22 @@ class DeformConv(nn.Module):
class DeformConvPack(DeformConv):
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(DeformConvPack, self).__init__(*args, **kwargs)
......@@ -260,6 +279,33 @@ class DeformConvPack(DeformConv):
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
# the key is different in early versions
# In version < 2, DeformConvPack loads previous benchmark models.
if (prefix + 'conv_offset.weight' not in state_dict
and prefix[:-1] + '_offset.weight' in state_dict):
state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
prefix[:-1] + '_offset.weight')
if (prefix + 'conv_offset.bias' not in state_dict
and prefix[:-1] + '_offset.bias' in state_dict):
state_dict[prefix +
'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
'_offset.bias')
if version is not None and version > 1:
from mmdet.apis import get_root_logger
logger = get_root_logger()
logger.info('DeformConvPack {} is upgraded to version 2.'.format(
prefix.rstrip('.')))
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
class ModulatedDeformConv(nn.Module):
......@@ -283,6 +329,9 @@ class ModulatedDeformConv(nn.Module):
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups,
......@@ -309,11 +358,27 @@ class ModulatedDeformConv(nn.Module):
class ModulatedDeformConvPack(ModulatedDeformConv):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset_mask = nn.Conv2d(
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
......@@ -324,14 +389,43 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
out = self.conv_offset_mask(x)
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
# the key is different in early versions
# In version < 2, ModulatedDeformConvPack
# loads previous benchmark models.
if (prefix + 'conv_offset.weight' not in state_dict
and prefix[:-1] + '_offset.weight' in state_dict):
state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
prefix[:-1] + '_offset.weight')
if (prefix + 'conv_offset.bias' not in state_dict
and prefix[:-1] + '_offset.bias' in state_dict):
state_dict[prefix +
'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
'_offset.bias')
if version is not None and version > 1:
from mmdet.apis import get_root_logger
logger = get_root_logger()
logger.info(
'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
prefix.rstrip('.')))
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment