import math

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from . import deform_conv_cuda


class DeformConvFunction(Function):
    """Autograd function wrapping the deformable convolution CUDA kernels.

    CPU execution is not implemented: both ``forward`` and ``backward``
    raise ``NotImplementedError`` for non-CUDA tensors.
    """

    @staticmethod
    def forward(ctx,
                input,
                offset,
                weight,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                im2col_step=64):
        """Run the deformable convolution forward pass.

        Args:
            input: 4D ``(N, C, H, W)`` CUDA tensor.
            offset: sampling offsets (typically predicted by a companion
                conv layer; layout is consumed by the CUDA kernel).
            weight: conv weight of shape ``(C_out, C_in // groups, kH, kW)``.
            stride, padding, dilation: standard conv hyper-parameters,
                int or pair (normalized with ``_pair``).
            groups: number of convolution groups.
            deformable_groups: number of offset groups.
            im2col_step: batch chunk size for the im2col buffers; the
                effective step must divide the batch size.

        Returns:
            Output tensor with shape given by ``_output_size``.

        Raises:
            ValueError: if ``input`` is not a 4D tensor.
            NotImplementedError: if ``input`` is not on CUDA.
        """
        if input is not None and input.dim() != 4:
            raise ValueError(
                "Expected 4D tensor as input, got {}D tensor instead.".format(
                    input.dim()))
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(
            DeformConvFunction._output_size(input, weight, ctx.padding,
                                            ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            # Never chunk by more than the batch itself.
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] %
                    cur_im2col_step) == 0, 'im2col step must divide batchsize'
            # Kernel expects (kW, kH) then (stride_w, stride_h), etc. —
            # note the reversed spatial order versus the Python pairs.
            deform_conv_cuda.deform_conv_forward_cuda(
                input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
                weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                ctx.dilation[0], ctx.groups, ctx.deformable_groups,
                cur_im2col_step)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. ``input``, ``offset`` and ``weight``.

        Returns one gradient slot per ``forward`` argument; the non-tensor
        hyper-parameters get ``None``.
        """
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] %
                    cur_im2col_step) == 0, 'im2col step must divide batchsize'

            # input and offset gradients come from a single fused kernel.
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_cuda.deform_conv_backward_input_cuda(
                    input, offset, grad_output, grad_input,
                    grad_offset, weight, ctx.bufs_[0], weight.size(3),
                    weight.size(2), ctx.stride[1], ctx.stride[0],
                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                    ctx.dilation[0], ctx.groups, ctx.deformable_groups,
                    cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_cuda.deform_conv_backward_parameters_cuda(
                    input, offset, grad_output,
                    grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                    weight.size(2), ctx.stride[1], ctx.stride[0],
                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                    ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
                    cur_im2col_step)

        return (grad_input, grad_offset, grad_weight, None, None, None, None,
                None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        """Return the output shape ``(N, C_out, *spatial)`` of the conv.

        Uses the standard convolution arithmetic
        ``(in + 2*pad - dilated_kernel) // stride + 1`` per spatial dim.

        Raises:
            ValueError: if any resulting dimension is non-positive.
        """
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(
                "convolution input is too small (output would be {})".format(
                    'x'.join(map(str, output_size))))
        return output_size


class ModulatedDeformConvFunction(Function):
    """Autograd function wrapping the modulated deformable conv (DCNv2)
    CUDA kernels.

    CUDA-only: both passes raise ``NotImplementedError`` on CPU tensors.
    """

    @staticmethod
    def forward(ctx,
                input,
                offset,
                mask,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1):
        """Run the modulated deformable convolution forward pass.

        Args:
            input: 4D ``(N, C, H, W)`` CUDA tensor.
            offset: sampling offsets for each kernel location.
            mask: modulation scalars for each kernel location.
            weight: conv weight ``(C_out, C_in // groups, kH, kW)``.
            bias: optional bias of shape ``(C_out,)``; ``None`` disables it.
            stride, padding, dilation: scalar hyper-parameters — the CUDA
                op receives the same value for both spatial dimensions.
            groups: number of convolution groups.
            deformable_groups: number of offset/mask groups.

        Returns:
            Output tensor with shape given by ``_infer_shape``.

        Raises:
            NotImplementedError: if ``input`` is not on CUDA.
        """
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(
            ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        deform_conv_cuda.modulated_deform_conv_cuda_forward(
            input, weight, bias, ctx._bufs[0], offset, mask, output,
            ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for input, offset, mask, weight and bias.

        Returns one gradient slot per ``forward`` argument; the scalar
        hyper-parameters get ``None``, and ``grad_bias`` is ``None`` when
        the layer was built without a bias.
        """
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        # One fused kernel produces all five gradients.
        deform_conv_cuda.modulated_deform_conv_cuda_backward(
            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
            grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
            grad_output, weight.shape[2], weight.shape[3], ctx.stride,
            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        """Return the output shape ``(N, C_out, H_out, W_out)``.

        Standard convolution arithmetic with the scalar stride/padding/
        dilation stored on ``ctx`` applied to both spatial dims.
        """
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding -
                      (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding -
                     (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out


# Functional interfaces: call the autograd Functions directly.
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply


class DeformConv(nn.Module):
    """Deformable convolution layer; sampling offsets are supplied by the
    caller at forward time.

    Args:
        in_channels: input channel count; must be divisible by ``groups``.
        out_channels: output channel count; must be divisible by ``groups``.
        kernel_size: int or pair.
        stride, padding, dilation: standard conv hyper-parameters.
        groups: number of convolution groups.
        deformable_groups: number of offset groups.
        bias: must be ``False`` — the underlying CUDA op has no bias term.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()

        assert not bias
        # Fixed message wording: these fire when the value is NOT divisible.
        assert in_channels % groups == 0, \
            'in_channels {} is not divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} is not divisible by groups {}'.format(
                out_channels, groups)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups,
                         *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init with bound 1/sqrt(fan_in), fan_in = C_in * kH * kW."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, offset):
        """Apply deformable convolution with caller-provided ``offset``."""
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)


class DeformConvPack(DeformConv):
    """DeformConv that predicts its own offsets with an internal conv layer."""

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        # 2 offsets (dy, dx) per kernel location per deformable group.
        # ``dilation`` must be forwarded: the offset map has to have the
        # same spatial size as the deformable-conv output, and ``padding``
        # is chosen for the dilated kernel.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero-init so the layer initially behaves like a plain conv."""
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)


class ModulatedDeformConv(nn.Module):
    """Modulated deformable convolution (DCNv2) layer.

    The caller supplies per-location ``offset`` and ``mask`` tensors at
    forward time; this module owns only the conv weight and optional bias.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        # stride/padding/dilation stay scalar: the Function passes the same
        # value for both spatial dimensions.
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias

        weight_shape = (out_channels, in_channels // groups) + self.kernel_size
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform weight init with bound 1/sqrt(fan_in); zero the bias."""
        fan_in = self.in_channels
        for dim in self.kernel_size:
            fan_in *= dim
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        """Apply modulated deformable conv with caller-provided offset/mask."""
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)


class ModulatedDeformConvPack(ModulatedDeformConv):
    """ModulatedDeformConv that predicts its own offsets and masks with an
    internal conv layer."""

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        # 3 channels per kernel location per deformable group: dy, dx, mask.
        # ``dilation`` must be forwarded so the offset/mask map matches the
        # deformable-conv output size when dilation > 1.
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero-init: zero offsets and (after sigmoid) uniform 0.5 masks."""
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset_mask(x)
        # Channel layout from the predictor: [dy | dx | mask logits].
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)