Commit 8bd1a5f7 authored by Wenwei Zhang, committed by Kai Chen

Encapsulate DCN into a ConvModule & Conv_layers (#1894)

* update code

* reformat

* fix load pretrain bug

* reformat

* fix key warnings

* use build_conv_layer & ConvModule

* rm unrelated file

* fix resnext group bug

* change cfg to pass test

* use build_conv_layer to claim dcn in resnets

* change _version = 2

* MDCN to DCNv2

* change dependency of mmcv from 0.2.15 to 0.2.16

* resolve comments

* rm comments

* get_root_logger()
parent 1ebb6cb8
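A quick sketch of the config-level change (not part of the diff): the `modulated` flag in a backbone's dcn dict is replaced by a conv-layer `type` key, 'DCN' for DeformConvPack and 'DCNv2' for ModulatedDeformConvPack, so the dict can be handed straight to build_conv_layer. A minimal usage sketch, assuming mmdet's compiled dcn ops are available; the channel sizes are illustrative only:

from mmdet.models.utils import build_conv_layer

# old style: dcn = dict(modulated=True, deformable_groups=1, fallback_on_stride=False)
# new style: `type` selects the class registered in conv_cfg (see the utils diff below)
dcn = dict(type='DCNv2', deformable_groups=1)
conv2 = build_conv_layer(
    dcn, 64, 64, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)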
@@ -11,7 +11,7 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         dcn=dict(
-            modulated=False, deformable_groups=1, fallback_on_stride=False),
+            type='DCN', deformable_groups=1, fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True)),
     neck=dict(
         type='FPN',
...
@@ -11,7 +11,7 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         dcn=dict(
-            modulated=False, deformable_groups=1, fallback_on_stride=False),
+            type='DCN', deformable_groups=1, fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True)),
     neck=dict(
         type='FPN',
...
@@ -10,7 +10,7 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         dcn=dict(
-            modulated=False, deformable_groups=1, fallback_on_stride=False),
+            type='DCN', deformable_groups=1, fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True)),
     neck=dict(
         type='FPN',
...
@@ -12,8 +12,7 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         dcn=dict(
-            modulated=False,
-            groups=32,
+            type='DCN',
             deformable_groups=1,
             fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True)),
...
@@ -10,7 +10,7 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         dcn=dict(
-            modulated=True, deformable_groups=4, fallback_on_stride=False),
+            type='DCNv2', deformable_groups=4, fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True)),
     neck=dict(
         type='FPN',
...
@@ -10,7 +10,7 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         dcn=dict(
-            modulated=True, deformable_groups=1, fallback_on_stride=False),
+            type='DCNv2', deformable_groups=1, fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True)),
     neck=dict(
         type='FPN',
...
@@ -10,7 +10,7 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         dcn=dict(
-            modulated=False, deformable_groups=1, fallback_on_stride=False),
+            type='DCN', deformable_groups=1, fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True)),
     neck=dict(
         type='FPN',
...
# model settings
model = dict(
type='MaskRCNN',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch',
dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_head=dict(
type='FCNMaskHead',
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=81,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/mask_rcnn_dconv_c3-c5_r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
@@ -13,7 +13,7 @@ model = dict(
             spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2),
         stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]],
         dcn=dict(
-            modulated=False, deformable_groups=1, fallback_on_stride=False),
+            type='DCN', deformable_groups=1, fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True),
     ),
     neck=dict(
...
@@ -13,7 +13,7 @@ model = dict(
             spatial_range=-1, num_heads=8, attention_type='1111', kv_stride=2),
         stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]],
         dcn=dict(
-            modulated=False, deformable_groups=1, fallback_on_stride=False),
+            type='DCN', deformable_groups=1, fallback_on_stride=False),
         stage_with_dcn=(False, True, True, True),
     ),
     neck=dict(
...
@@ -15,7 +15,7 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         dcn=dict(
-            modulated=False,
+            type='DCN',
             groups=64,
             deformable_groups=1,
             fallback_on_stride=False),
...
@@ -5,7 +5,7 @@ from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmdet.models.plugins import GeneralizedAttention
-from mmdet.ops import ContextBlock, DeformConv, ModulatedDeformConv
+from mmdet.ops import ContextBlock
 from ..registry import BACKBONES
 from ..utils import build_conv_layer, build_norm_layer
@@ -143,10 +143,8 @@ class Bottleneck(nn.Module):
             bias=False)
         self.add_module(self.norm1_name, norm1)
         fallback_on_stride = False
-        self.with_modulated_dcn = False
         if self.with_dcn:
-            fallback_on_stride = dcn.get('fallback_on_stride', False)
-            self.with_modulated_dcn = dcn.get('modulated', False)
+            fallback_on_stride = dcn.pop('fallback_on_stride', False)
         if not self.with_dcn or fallback_on_stride:
             self.conv2 = build_conv_layer(
                 conv_cfg,
@@ -158,30 +156,17 @@ class Bottleneck(nn.Module):
                 dilation=dilation,
                 bias=False)
         else:
-            assert conv_cfg is None, 'conv_cfg must be None for DCN'
-            self.deformable_groups = dcn.get('deformable_groups', 1)
-            if not self.with_modulated_dcn:
-                conv_op = DeformConv
-                offset_channels = 18
-            else:
-                conv_op = ModulatedDeformConv
-                offset_channels = 27
-            self.conv2_offset = nn.Conv2d(
-                planes,
-                self.deformable_groups * offset_channels,
-                kernel_size=3,
-                stride=self.conv2_stride,
-                padding=dilation,
-                dilation=dilation)
-            self.conv2 = conv_op(
+            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+            self.conv2 = build_conv_layer(
+                dcn,
                 planes,
                 planes,
                 kernel_size=3,
                 stride=self.conv2_stride,
                 padding=dilation,
                 dilation=dilation,
-                deformable_groups=self.deformable_groups,
                 bias=False)
         self.add_module(self.norm2_name, norm2)
         self.conv3 = build_conv_layer(
             conv_cfg,
@@ -224,17 +209,7 @@ class Bottleneck(nn.Module):
             out = self.norm1(out)
             out = self.relu(out)
 
-            if not self.with_dcn:
-                out = self.conv2(out)
-            elif self.with_modulated_dcn:
-                offset_mask = self.conv2_offset(out)
-                offset = offset_mask[:, :18 * self.deformable_groups, :, :]
-                mask = offset_mask[:, -9 * self.deformable_groups:, :, :]
-                mask = mask.sigmoid()
-                out = self.conv2(out, offset, mask)
-            else:
-                offset = self.conv2_offset(out)
-                out = self.conv2(out, offset)
+            out = self.conv2(out)
             out = self.norm2(out)
             out = self.relu(out)
...
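Note on the hunks above: fallback_on_stride is now popped rather than read with get, because every key left in the dcn dict, apart from type, is forwarded to the conv layer's constructor by build_conv_layer. A small illustration with hypothetical values:

dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False)
fallback_on_stride = dcn.pop('fallback_on_stride', False)
# only constructor-compatible keys remain for build_conv_layer to pass through
assert dcn == dict(type='DCN', deformable_groups=1)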
@@ -2,7 +2,6 @@ import math
 import torch.nn as nn
 
-from mmdet.ops import DeformConv, ModulatedDeformConv
 from ..registry import BACKBONES
 from ..utils import build_conv_layer, build_norm_layer
 from .resnet import Bottleneck as _Bottleneck
@@ -41,8 +40,7 @@ class Bottleneck(_Bottleneck):
         fallback_on_stride = False
         self.with_modulated_dcn = False
         if self.with_dcn:
-            fallback_on_stride = self.dcn.get('fallback_on_stride', False)
-            self.with_modulated_dcn = self.dcn.get('modulated', False)
+            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
         if not self.with_dcn or fallback_on_stride:
             self.conv2 = build_conv_layer(
                 self.conv_cfg,
@@ -56,22 +54,8 @@ class Bottleneck(_Bottleneck):
                 bias=False)
         else:
             assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
-            groups = self.dcn.get('groups', 1)
-            deformable_groups = self.dcn.get('deformable_groups', 1)
-            if not self.with_modulated_dcn:
-                conv_op = DeformConv
-                offset_channels = 18
-            else:
-                conv_op = ModulatedDeformConv
-                offset_channels = 27
-            self.conv2_offset = nn.Conv2d(
-                width,
-                deformable_groups * offset_channels,
-                kernel_size=3,
-                stride=self.conv2_stride,
-                padding=self.dilation,
-                dilation=self.dilation)
-            self.conv2 = conv_op(
+            self.conv2 = build_conv_layer(
+                self.dcn,
                 width,
                 width,
                 kernel_size=3,
@@ -79,8 +63,8 @@ class Bottleneck(_Bottleneck):
                 padding=self.dilation,
                 dilation=self.dilation,
                 groups=groups,
-                deformable_groups=deformable_groups,
                 bias=False)
         self.add_module(self.norm2_name, norm2)
         self.conv3 = build_conv_layer(
             self.conv_cfg,
...
@@ -3,12 +3,15 @@ import warnings
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 
+from mmdet.ops import DeformConvPack, ModulatedDeformConvPack
 from .conv_ws import ConvWS2d
 from .norm import build_norm_layer
 
 conv_cfg = {
     'Conv': nn.Conv2d,
     'ConvWS': ConvWS2d,
+    'DCN': DeformConvPack,
+    'DCNv2': ModulatedDeformConvPack,
     # TODO: octave conv
 }
...
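For context, the registry above is consumed by build_conv_layer, which looks up the class via the cfg's type key and forwards the remaining keys as constructor arguments. A simplified sketch of that dispatch (the real implementation in mmdet/models/utils also validates the cfg):

def build_conv_layer(cfg, *args, **kwargs):
    # None falls back to a plain convolution
    cfg_ = dict(type='Conv') if cfg is None else cfg.copy()
    layer_type = cfg_.pop('type')
    conv_layer = conv_cfg[layer_type]  # e.g. 'DCN' -> DeformConvPack
    # remaining cfg keys (e.g. deformable_groups) become constructor kwargs
    return conv_layer(*args, **kwargs, **cfg_)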
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
-from torch.nn.modules.utils import _pair
+from torch.nn.modules.utils import _pair, _single
 
 from . import deform_conv_cuda
@@ -24,7 +24,7 @@ class DeformConvFunction(Function):
                 im2col_step=64):
         if input is not None and input.dim() != 4:
             raise ValueError(
-                "Expected 4D tensor as input, got {}D tensor instead.".format(
+                'Expected 4D tensor as input, got {}D tensor instead.'.format(
                     input.dim()))
         ctx.stride = _pair(stride)
         ctx.padding = _pair(padding)
@@ -105,7 +105,7 @@ class DeformConvFunction(Function):
             output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
         if not all(map(lambda s: s > 0, output_size)):
             raise ValueError(
-                "convolution input is too small (output would be {})".format(
+                'convolution input is too small (output would be {})'.format(
                     'x'.join(map(str, output_size))))
         return output_size
@@ -217,6 +217,9 @@ class DeformConv(nn.Module):
         self.dilation = _pair(dilation)
         self.groups = groups
         self.deformable_groups = deformable_groups
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
 
         self.weight = nn.Parameter(
             torch.Tensor(out_channels, in_channels // self.groups,
@@ -237,6 +240,22 @@ class DeformConv(nn.Module):
 
 
 class DeformConvPack(DeformConv):
+    """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+    """
+
+    _version = 2
 
     def __init__(self, *args, **kwargs):
         super(DeformConvPack, self).__init__(*args, **kwargs)
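The _version = 2 attribute ties into PyTorch's state-dict versioning: nn.Module.state_dict() records each submodule's _version in the state dict metadata, and _load_from_state_dict (added in the next hunk) reads it back as local_metadata['version'] to tell old checkpoints from new ones. A quick illustration, assuming the dcn CUDA extension is compiled (reading the private _metadata attribute is a PyTorch implementation detail):

conv = DeformConvPack(16, 16, kernel_size=3, padding=1)
sd = conv.state_dict()
print(sd._metadata[''])  # {'version': 2}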
@@ -260,6 +279,33 @@ class DeformConvPack(DeformConv):
         return deform_conv(x, offset, self.weight, self.stride, self.padding,
                            self.dilation, self.groups, self.deformable_groups)
 
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+
+        if version is None or version < 2:
+            # the key is different in early versions
+            # In version < 2, DeformConvPack loads previous benchmark models.
+            if (prefix + 'conv_offset.weight' not in state_dict
+                    and prefix[:-1] + '_offset.weight' in state_dict):
+                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+                    prefix[:-1] + '_offset.weight')
+            if (prefix + 'conv_offset.bias' not in state_dict
+                    and prefix[:-1] + '_offset.bias' in state_dict):
+                state_dict[prefix +
+                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+                                                                '_offset.bias')
+
+        if version is not None and version > 1:
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
+            logger.info('DeformConvPack {} is upgraded to version 2.'.format(
+                prefix.rstrip('.')))
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys,
+                                      error_msgs)
 
 
 class ModulatedDeformConv(nn.Module):
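The key remapping above exists because the offset convolution used to be a sibling of conv2 in the Bottleneck (conv2_offset), whereas it is now a child module of the DCN layer itself. A worked example of the translation with a hypothetical prefix:

prefix = 'backbone.layer2.0.conv2.'       # prefix handed in by load_state_dict
old_key = prefix[:-1] + '_offset.weight'  # 'backbone.layer2.0.conv2_offset.weight'
new_key = prefix + 'conv_offset.weight'   # 'backbone.layer2.0.conv2.conv_offset.weight'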
@@ -283,6 +329,9 @@ class ModulatedDeformConv(nn.Module):
         self.groups = groups
         self.deformable_groups = deformable_groups
         self.with_bias = bias
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
 
         self.weight = nn.Parameter(
             torch.Tensor(out_channels, in_channels // groups,
@@ -309,11 +358,27 @@ class ModulatedDeformConv(nn.Module):
 
 
 class ModulatedDeformConvPack(ModulatedDeformConv):
+    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+    """
+
+    _version = 2
 
     def __init__(self, *args, **kwargs):
         super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
-        self.conv_offset_mask = nn.Conv2d(
+        self.conv_offset = nn.Conv2d(
             self.in_channels,
             self.deformable_groups * 3 * self.kernel_size[0] *
             self.kernel_size[1],
@@ -324,14 +389,43 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
         self.init_offset()
 
     def init_offset(self):
-        self.conv_offset_mask.weight.data.zero_()
-        self.conv_offset_mask.bias.data.zero_()
+        self.conv_offset.weight.data.zero_()
+        self.conv_offset.bias.data.zero_()
 
     def forward(self, x):
-        out = self.conv_offset_mask(x)
+        out = self.conv_offset(x)
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
         mask = torch.sigmoid(mask)
         return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                      self.stride, self.padding, self.dilation,
                                      self.groups, self.deformable_groups)
 
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+
+        if version is None or version < 2:
+            # the key is different in early versions
+            # In version < 2, ModulatedDeformConvPack
+            # loads previous benchmark models.
+            if (prefix + 'conv_offset.weight' not in state_dict
+                    and prefix[:-1] + '_offset.weight' in state_dict):
+                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+                    prefix[:-1] + '_offset.weight')
+            if (prefix + 'conv_offset.bias' not in state_dict
+                    and prefix[:-1] + '_offset.bias' in state_dict):
+                state_dict[prefix +
+                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+                                                                '_offset.bias')
+
+        if version is not None and version > 1:
+            from mmdet.apis import get_root_logger
+            logger = get_root_logger()
+            logger.info(
+                'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
+                    prefix.rstrip('.')))
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys,
+                                      error_msgs)
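As a closing note on the forward pass above, the offset/mask channel split is the same arithmetic that used to be spelled out in the backbones (the removed 18/27 constants). A small check, assuming kernel_size=3 and deformable_groups=1:

deformable_groups, kh, kw = 1, 3, 3
offset_mask_channels = deformable_groups * 3 * kh * kw  # 27, produced by conv_offset
offset_channels = deformable_groups * 2 * kh * kw       # 18, x/y offsets per sample
mask_channels = deformable_groups * kh * kw             # 9, one mask value per sample
assert offset_mask_channels == offset_channels + mask_channels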