Unverified commit 3b6ae96d, authored by Kai Chen, committed by GitHub

Merge pull request #299 from yhcao6/dcn_cpp_extension

Add dcn group param support
parents e2227ddb 2b6104d3
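
In short, this merge threads a `groups` argument through the whole deformable-convolution stack: the autograd `Function`s and `nn.Module` wrappers (`DeformConv`, `ModulatedDeformConv`, `ModulatedDeformConvPack`), the CUDA extension entry points, and the ResNeXt backbone, whose `dcn` config dict can now request grouped deformable convs (e.g. `groups=32` to match the 32x4d cardinality).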
@@ -193,6 +193,7 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m
| R-50-FPN | Faster | pytorch | - | dpool | 1x | 4.6 | 0.714 | 8.7 | 37.9 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dpool_r50_fpn_1x_20190125-f4fc1d70.pth) |
| R-50-FPN | Faster | pytorch | - | mdpool | 1x | 5.2 | 0.769 | 8.2 | 38.1 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_mdpool_r50_fpn_1x_20190125-473d0f3d.pth) |
| R-101-FPN | Faster | pytorch | dconv(c3-c5) | - | 1x | 5.8 | 0.811 | 8.0 | 42.1 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-a7e31b65.pth) |
| X-101-32x4d-FPN | Faster | pytorch | dconv(c3-c5) | - | 1x | 7.1 | 1.126 | 6.6 | 43.5 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x_20190201-6d46376f.pth) |
| R-50-FPN | Mask | pytorch | dconv(c3-c5) | - | 1x | 4.5 | 0.712 | 7.7 | 41.1 | 37.2 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/mask_rcnn_dconv_c3-c5_r50_fpn_1x_20190125-4f94ff79.pth) |
| R-50-FPN | Mask | pytorch | mdconv(c3-c5) | - | 1x | 4.5 | 0.712 | 7.7 | 41.4 | 37.4 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/mask_rcnn_mdconv_c3-c5_r50_fpn_1x_20190125-c5601dc3.pth) |
| R-101-FPN | Mask | pytorch | dconv(c3-c5) | - | 1x | 6.4 | 0.939 | 6.5 | 43.2 | 38.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/mask_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-decb6db5.pth) |
...
# model settings
model = dict(
type='FasterRCNN',
pretrained='open-mmlab://resnext101_32x4d',
backbone=dict(
type='ResNeXt',
depth=101,
groups=32,
base_width=4,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch',
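        # dcn: 'groups' keeps conv2's grouped topology (32, the ResNeXt
        # cardinality) when it is replaced by a deformable conv;
        # 'deformable_groups' is the number of independent offset fields
        # (1 = a single field shared by all channels).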
dcn=dict(
modulated=False,
groups=32,
deformable_groups=1,
fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
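
The benchmark table's `mdconv` rows use the same config with the modulated variant switched on; only the `dcn` dict changes (hypothetical fragment for illustration, mirroring the keys read by `Bottleneck` below):

```python
# mdconv variant (sketch, not a file in this diff): flip 'modulated' on
dcn = dict(
    modulated=True,        # selects ModulatedDeformConv instead of DeformConv
    groups=32,
    deformable_groups=1,
    fallback_on_stride=False)
```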
@@ -2,8 +2,9 @@ import math
import torch.nn as nn

from mmdet.ops import DeformConv, ModulatedDeformConv
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
from ..registry import BACKBONES
from ..utils import build_norm_layer
@@ -22,15 +23,12 @@ class Bottleneck(_Bottleneck):
        else:
            width = math.floor(self.planes * (base_width / 64)) * groups
        self.norm1_name, norm1 = build_norm_layer(
            self.normalize, width, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.normalize, width, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.normalize, self.planes * self.expansion, postfix=3)

        self.conv1 = nn.Conv2d(
            self.inplanes,
@@ -39,6 +37,12 @@ class Bottleneck(_Bottleneck):
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = self.dcn.get('modulated', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(
                width,
                width,
@@ -48,6 +52,32 @@ class Bottleneck(_Bottleneck):
                dilation=self.dilation,
                groups=groups,
                bias=False)
        else:
            groups = self.dcn.get('groups', 1)
            deformable_groups = self.dcn.get('deformable_groups', 1)
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                width,
                deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation)
            self.conv2 = conv_op(
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                deformable_groups=deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = nn.Conv2d(
            width, self.planes * self.expansion, kernel_size=1, bias=False)
@@ -64,7 +94,8 @@ def make_res_layer(block,
                   base_width=4,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN'),
                   dcn=None):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
@@ -89,7 +120,8 @@ def make_res_layer(block,
            base_width=base_width,
            style=style,
            with_cp=with_cp,
            normalize=normalize,
            dcn=dcn))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
@@ -102,7 +134,8 @@ def make_res_layer(block,
                base_width=base_width,
                style=style,
                with_cp=with_cp,
                normalize=normalize,
                dcn=dcn))
    return nn.Sequential(*layers)
@@ -150,6 +183,7 @@ class ResNeXt(ResNet):
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = self.strides[i]
            dilation = self.dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
@@ -162,7 +196,8 @@ class ResNeXt(ResNet):
                base_width=self.base_width,
                style=self.style,
                with_cp=self.with_cp,
                normalize=self.normalize,
                dcn=dcn)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
...
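
A note on the constants above: a plain deformable conv predicts two offsets (dx, dy) for each of the 3 x 3 = 9 sampling points, hence `offset_channels = 18` per deformable group; the modulated variant additionally predicts a 9-channel mask, hence 27. A minimal shape check (illustrative sketch; assumes a CUDA build of the `mmdet.ops` extension):

```python
import torch
from mmdet.ops import DeformConv

k, dg, groups = 3, 1, 32
offset_channels = 2 * k * k          # 18: (dx, dy) per sampling point
# the modulated variant would use 3 * k * k = 27 (offsets + mask)

conv = DeformConv(64, 64, kernel_size=k, padding=1,
                  groups=groups, deformable_groups=dg).cuda()
x = torch.randn(2, 64, 32, 32).cuda()
offset = torch.zeros(2, dg * offset_channels, 32, 32).cuda()  # zero offsets = regular sampling grid
out = conv(x, offset)                # -> torch.Size([2, 64, 32, 32])
```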
@@ -15,6 +15,7 @@ class DeformConvFunction(Function):
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                im2col_step=64):
        if input is not None and input.dim() != 4:
@@ -24,6 +25,7 @@ class DeformConvFunction(Function):
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step
@@ -45,7 +47,8 @@ class DeformConvFunction(Function):
                input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
                weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                ctx.dilation[0], ctx.groups, ctx.deformable_groups,
                cur_im2col_step)
        return output

    @staticmethod
@@ -69,7 +72,8 @@ class DeformConvFunction(Function):
                    grad_offset, weight, ctx.bufs_[0], weight.size(3),
                    weight.size(2), ctx.stride[1], ctx.stride[0],
                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                    ctx.dilation[0], ctx.groups, ctx.deformable_groups,
                    cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
@@ -78,9 +82,11 @@ class DeformConvFunction(Function):
                    grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                    weight.size(2), ctx.stride[1], ctx.stride[0],
                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                    ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
                    cur_im2col_step)

        return (grad_input, grad_offset, grad_weight, None, None, None, None,
                None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
@@ -111,10 +117,12 @@ class ModulatedDeformConvFunction(Function):
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1):
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
@@ -131,7 +139,7 @@ class ModulatedDeformConvFunction(Function):
            input, weight, bias, ctx._bufs[0], offset, mask, output,
            ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
@@ -149,12 +157,12 @@ class ModulatedDeformConvFunction(Function):
            grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
            grad_output, weight.shape[2], weight.shape[3], ctx.stride,
            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
...
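
Note the autograd bookkeeping in the two `backward` methods: autograd requires `backward` to return exactly one value per argument of `forward`, so threading `groups` through `forward` adds one trailing `None` to each return tuple. A toy illustration of that contract (not from the repo):

```python
import torch
from torch.autograd import Function

class Scale(Function):
    @staticmethod
    def forward(ctx, x, factor, groups=1):  # three inputs ...
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # ... so backward returns three values; non-tensor args get None.
        return grad_out * ctx.factor, None, None

x = torch.ones(2, requires_grad=True)
Scale.apply(x, 2.0, 4).sum().backward()   # x.grad == tensor([2., 2.])
```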
@@ -16,20 +16,30 @@ class DeformConv(nn.Module):
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        assert not bias
        super(DeformConv, self).__init__()
        assert in_channels % groups == 0, \
            'in_channels {} is not divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} is not divisible by groups {}'.format(
                out_channels, groups)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups,
                         *self.kernel_size))

        self.reset_parameters()
@@ -42,7 +52,8 @@ class DeformConv(nn.Module):
    def forward(self, input, offset):
        return deform_conv(input, offset, self.weight, self.stride,
                           self.padding, self.dilation, self.groups,
                           self.deformable_groups)


class ModulatedDeformConv(nn.Module):
@@ -54,6 +65,7 @@ class ModulatedDeformConv(nn.Module):
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
@@ -63,11 +75,13 @@ class ModulatedDeformConv(nn.Module):
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
@@ -84,9 +98,9 @@ class ModulatedDeformConv(nn.Module):
            self.bias.data.zero_()

    def forward(self, input, offset, mask):
        return modulated_deform_conv(
            input, offset, mask, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups, self.deformable_groups)


class ModulatedDeformConvPack(ModulatedDeformConv):
@@ -98,14 +112,15 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConvPack, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, deformable_groups, bias)

        self.conv_offset_mask = nn.Conv2d(
            self.in_channels // self.groups,
            self.deformable_groups * 3 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
@@ -123,6 +138,6 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(
            input, offset, mask, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups, self.deformable_groups)
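
A usage sketch for the grouped modules (illustrative shapes; assumes the compiled CUDA extension): with `deformable_groups=1` and a 3 x 3 kernel, the offset map has 18 channels and the mask 9, regardless of `groups`, which only shrinks the per-group weight fan-in to `in_channels // groups`:

```python
import torch
from mmdet.ops import ModulatedDeformConv

dg, k = 1, 3
m = ModulatedDeformConv(32, 64, kernel_size=k, padding=1,
                        groups=2, deformable_groups=dg).cuda()
x = torch.randn(1, 32, 16, 16).cuda()
offset = torch.zeros(1, dg * 2 * k * k, 16, 16).cuda()  # 18 channels
mask = torch.ones(1, dg * k * k, 16, 16).cuda()         # 9 channels
y = m(x, offset, mask)                                  # -> (1, 64, 16, 16)
print(m.weight.shape)                                   # (64, 16, 3, 3): fan-in is 32 // 2
```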
@@ -62,7 +62,7 @@ void modulated_deformable_col2im_coord_cuda(const at::Tensor data_col, const at:
void shape_check(at::Tensor input, at::Tensor offset,
                 at::Tensor *gradOutput, at::Tensor weight, int kH, int kW,
                 int dH, int dW, int padH, int padW, int dilationH,
                 int dilationW, int group, int deformable_group)
{
  AT_CHECK(weight.ndimension() == 4,
@@ -105,7 +105,7 @@ void shape_check(at::Tensor input, at::Tensor offset,
  AT_CHECK(ndim == 3 || ndim == 4,
           "3D or 4D input tensor expected but got: %s", ndim);

  long nInputPlane = weight.size(1) * group;
  long inputHeight = input.size(dimh);
  long inputWidth = input.size(dimw);
  long nOutputPlane = weight.size(0);
@@ -154,7 +154,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                             at::Tensor offset, at::Tensor output,
                             at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH,
                             int dilationW, int dilationH, int group,
                             int deformable_group, int im2col_step)
{
@@ -164,7 +164,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
  // todo: possibly change data indexing because of parallel_imgs
  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
@@ -207,6 +207,8 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
  at::Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth}, output.type());

  output_buffer = output_buffer.view({output_buffer.size(0), group, output_buffer.size(1) / group, output_buffer.size(2), output_buffer.size(3)});

  for (int elt = 0; elt < batchSize / im2col_step; elt++)
  {
    deformable_im2col(
@@ -214,10 +216,17 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
        inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
        im2col_step, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++){
      output_buffer[elt][g] =
          output_buffer[elt][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output_buffer[elt][g]);
    }
  }

  output_buffer = output_buffer.view({output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), output_buffer.size(3), output_buffer.size(4)});

  output_buffer = output_buffer.view(
      {batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth});
  output_buffer.transpose_(1, 2);
@@ -241,11 +250,11 @@ int deform_conv_backward_input_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradInput, at::Tensor gradOffset, at::Tensor weight,
    at::Tensor columns, int kW, int kH, int dW, int dH, int padW, int padH,
    int dilationW, int dilationH, int group, int deformable_group, int im2col_step)
{
  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH,
              padW, dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
@@ -292,7 +301,17 @@ int deform_conv_backward_input_cuda(
  for (int elt = 0; elt < batchSize / im2col_step; elt++)
  {
    // divide into groups
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
    gradOutput = gradOutput.view({gradOutput.size(0), group, gradOutput.size(1) / group, gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});

    for (int g = 0; g < group; g++){
      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
    }

    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradOutput = gradOutput.view({gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});

    deformable_col2im_coord(
        columns, input[elt], offset[elt],
@@ -329,7 +348,7 @@ int deform_conv_backward_parameters_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight,  // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group, int deformable_group,
    float scale, int im2col_step)
{
@@ -338,7 +357,7 @@ int deform_conv_backward_parameters_cuda(
  // todo: add im2col_step as input
  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW,
              padH, padW, dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
@@ -395,9 +414,19 @@ int deform_conv_backward_parameters_cuda(
        inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
        im2col_step, deformable_group, columns);

    // divide into group
    gradOutputBuffer = gradOutputBuffer.view({gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    gradWeight = gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), gradWeight.size(2), gradWeight.size(3)});

    for (int g = 0; g < group; g++){
      gradWeight[g] = gradWeight[g].flatten(1).addmm_(
          gradOutputBuffer[elt][g].flatten(1), columns[g].transpose(1, 0), 1.0, scale)
          .view_as(gradWeight[g]);
    }

    gradOutputBuffer = gradOutputBuffer.view({gradOutputBuffer.size(0), gradOutputBuffer.size(1) * gradOutputBuffer.size(2), gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), gradWeight.size(2), gradWeight.size(3), gradWeight.size(4)});
  }

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
@@ -413,6 +442,7 @@ int deform_conv_backward_parameters_cuda(
  return 1;
}

void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
                                        at::Tensor bias, at::Tensor ones,
                                        at::Tensor offset, at::Tensor mask,
@@ -420,7 +450,7 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
                                        int kernel_h, int kernel_w,
                                        const int stride_h, const int stride_w,
                                        const int pad_h, const int pad_w,
                                        const int dilation_h, const int dilation_w, const int group,
                                        const int deformable_group, const bool with_bias)
{
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
@@ -439,9 +469,9 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h_, kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
@@ -458,6 +488,8 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
  // resize temporary columns
  columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());

  output = output.view({output.size(0), group, output.size(1) / group, output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++)
  {
    modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
@@ -466,9 +498,20 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
                                     pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                                     deformable_group, columns);

    // divide into group
    weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++){
      output[b][g] = output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]);
    }

    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), weight.size(3), weight.size(4)});
    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  output = output.view({output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)});

  if (with_bias){
    output += bias.view({1, bias.size(0), 1, 1});
  }
@@ -484,7 +527,7 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
                                         int kernel_h, int kernel_w,
                                         int stride_h, int stride_w,
                                         int pad_h, int pad_w,
                                         int dilation_h, int dilation_w, int group,
                                         int deformable_group, const bool with_bias)
{
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
@@ -501,9 +544,9 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h_, kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
@@ -518,9 +561,20 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, input.type());

  grad_output = grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, grad_output.size(2), grad_output.size(3)});

  for (int b = 0; b < batch; b++)
  {
    // divide into group
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++){
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(columns, input[b], offset[b], mask[b],
@@ -545,14 +599,27 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
                                          dilation_h, dilation_w, deformable_group,
                                          columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++){
      grad_weight[g] = grad_weight[g].flatten(1).addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)).view_as(grad_weight[g]);
      if (with_bias){
        grad_bias[g] = grad_bias[g].view({-1, 1}).addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})).view(-1);
      }
    }

    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }

  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), grad_output.size(2), grad_output.size(3), grad_output.size(4)});
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda, "deform forward (CUDA)");
...
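
The per-group loops above implement the usual grouped-convolution factorization of the im2col GEMM: reshape the column buffer and the weights into `group` slices, multiply slice-wise, and concatenate along the output-channel dimension. A standalone PyTorch sketch of the same reshape-and-matmul pattern (illustrative sizes, not part of the extension):

```python
import torch

# One grouped GEMM step:
# columns: (C_in*k*k, H*W) im2col buffer; weight: (C_out, C_in//g, k, k)
g, c_in, c_out, k, hw = 4, 8, 16, 3, 25
columns = torch.randn(c_in * k * k, hw)
weight = torch.randn(c_out, c_in // g, k, k)

cols = columns.view(g, c_in * k * k // g, hw)           # (g, C_in/g*k*k, H*W)
w = weight.view(g, c_out // g, -1)                      # (g, C_out/g, C_in/g*k*k)
out = torch.stack([w[i] @ cols[i] for i in range(g)])   # (g, C_out/g, H*W)
out = out.reshape(c_out, hw)                            # concat groups along channels
```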