import math

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single

from mmdet.utils import print_log
from . import deform_conv_cuda


class DeformConvFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                offset,
                weight,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                im2col_step=64):
        if input is not None and input.dim() != 4:
            raise ValueError(
                'Expected 4D tensor as input, got {}D tensor instead.'.format(
                    input.dim()))
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(
            DeformConvFunction._output_size(input, weight, ctx.padding,
                                            ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] %
                    cur_im2col_step) == 0, 'im2col step must divide batchsize'
            deform_conv_cuda.deform_conv_forward_cuda(
                input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
                weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                ctx.dilation[0], ctx.groups, ctx.deformable_groups,
                cur_im2col_step)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] %
                    cur_im2col_step) == 0, 'im2col step must divide batchsize'

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_cuda.deform_conv_backward_input_cuda(
                    input, offset, grad_output, grad_input,
                    grad_offset, weight, ctx.bufs_[0], weight.size(3),
                    weight.size(2), ctx.stride[1], ctx.stride[0],
                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                    ctx.dilation[0], ctx.groups, ctx.deformable_groups,
                    cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_cuda.deform_conv_backward_parameters_cuda(
                    input, offset, grad_output,
                    grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                    weight.size(2), ctx.stride[1], ctx.stride[0],
                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                    ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
                    cur_im2col_step)

        return (grad_input, grad_offset, grad_weight, None, None, None, None,
                None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(
                'convolution input is too small (output would be {})'.format(
                    'x'.join(map(str, output_size))))
        return output_size
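
    # Worked example (illustrative, not from the original source): for a
    # (1, 64, 56, 56) input with a 3x3 kernel, padding 1, dilation 1 and
    # stride 2, each spatial dim becomes (56 + 2 * 1 - 3) // 2 + 1 = 28,
    # so _output_size returns (1, out_channels, 28, 28).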


class ModulatedDeformConvFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                offset,
                mask,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1):
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(
            ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        deform_conv_cuda.modulated_deform_conv_cuda_forward(
            input, weight, bias, ctx._bufs[0], offset, mask, output,
            ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        deform_conv_cuda.modulated_deform_conv_cuda_backward(
            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
            grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
            grad_output, weight.shape[2], weight.shape[3], ctx.stride,
            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding -
                      (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding -
                     (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out


deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
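
# Usage sketch for the functional interface (illustrative; the tensor names
# and shapes below are assumptions, and a CUDA device is required since the
# CPU path raises NotImplementedError):
#
#   x = torch.randn(1, 16, 32, 32, device='cuda')
#   weight = torch.randn(32, 16, 3, 3, device='cuda')
#   # offset carries 2 * deformable_groups * kH * kW channels (an x/y pair
#   # per kernel sampling location) at the output resolution
#   offset = torch.randn(1, 2 * 1 * 3 * 3, 32, 32, device='cuda')
#   out = deform_conv(x, offset, weight, 1, 1)  # stride=1, padding=1
#   # out.shape == (1, 32, 32, 32)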


class DeformConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()

        assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} is not divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} is not divisible by groups {}'.format(
                out_channels, groups)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups,
                         *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, offset):
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)


class DeformConvPack(DeformConv):
    """A Deformable Conv Encapsulation that acts as normal Conv layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool): Same as nn.Conv2d (DeformConv itself only supports
            bias=False).
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if version is None or version < 2:
            # the key is different in early versions
            # In version < 2, DeformConvPack loads previous benchmark models.
            if (prefix + 'conv_offset.weight' not in state_dict
                    and prefix[:-1] + '_offset.weight' in state_dict):
                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
                    prefix[:-1] + '_offset.weight')
            if (prefix + 'conv_offset.bias' not in state_dict
                    and prefix[:-1] + '_offset.bias' in state_dict):
                state_dict[prefix +
                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
                                                                '_offset.bias')

        if version is not None and version > 1:
            print_log(
                'DeformConvPack {} is upgraded to version 2.'.format(
                    prefix.rstrip('.')),
                logger='root')

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

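# Usage sketch for DeformConvPack (illustrative; shapes are assumptions).
# Offsets are predicted internally by `conv_offset`, so it can be used as a
# drop-in replacement for nn.Conv2d on a CUDA device:
#
#   conv = DeformConvPack(16, 32, kernel_size=3, padding=1).cuda()
#   x = torch.randn(2, 16, 64, 64, device='cuda')
#   out = conv(x)  # out.shape == (2, 32, 64, 64)
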

class ModulatedDeformConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)

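# Usage sketch for ModulatedDeformConv with externally computed offsets and
# masks (illustrative; names and shapes are assumptions, CUDA device
# required):
#
#   m = ModulatedDeformConv(16, 32, 3, padding=1).cuda()
#   x = torch.randn(1, 16, 32, 32, device='cuda')
#   offset = torch.randn(1, 2 * 1 * 3 * 3, 32, 32, device='cuda')
#   mask = torch.sigmoid(torch.randn(1, 1 * 3 * 3, 32, 32, device='cuda'))
#   out = m(x, offset, mask)  # out.shape == (1, 32, 32, 32)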

class ModulatedDeformConvPack(ModulatedDeformConv):
    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool): Same as nn.Conv2d.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if version is None or version < 2:
            # the key is different in early versions
            # In version < 2, ModulatedDeformConvPack
            # loads previous benchmark models.
            if (prefix + 'conv_offset.weight' not in state_dict
                    and prefix[:-1] + '_offset.weight' in state_dict):
                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
                    prefix[:-1] + '_offset.weight')
            if (prefix + 'conv_offset.bias' not in state_dict
                    and prefix[:-1] + '_offset.bias' in state_dict):
                state_dict[prefix +
                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
                                                                '_offset.bias')

        if version is not None and version > 1:
            print_log(
                'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
                    prefix.rstrip('.')),
                logger='root')

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)
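
# Usage sketch for ModulatedDeformConvPack (illustrative; shapes are
# assumptions). The offsets and the sigmoid-activated modulation mask are
# both predicted by the internal `conv_offset` layer, so it behaves like a
# normal conv layer on a CUDA device:
#
#   conv = ModulatedDeformConvPack(16, 32, kernel_size=3, padding=1).cuda()
#   x = torch.randn(2, 16, 64, 64, device='cuda')
#   out = conv(x)  # out.shape == (2, 32, 64, 64)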