# deform_conv.py
import math

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair

from ..functions.deform_conv import deform_conv, modulated_deform_conv


class DeformConv(nn.Module):
    """Deformable 2D convolution without a bias term.

    The module owns only the convolution weight; the sampling offsets are
    supplied by the caller on every ``forward`` call and the actual
    computation is delegated to ``deform_conv``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple[int]): Kernel size; normalized to a pair.
        stride (int or tuple[int]): Normalized to a pair and forwarded to
            ``deform_conv``. Default: 1.
        padding (int or tuple[int]): Normalized to a pair and forwarded to
            ``deform_conv``. Default: 0.
        dilation (int or tuple[int]): Normalized to a pair and forwarded to
            ``deform_conv``. Default: 1.
        deformable_groups (int): Forwarded to ``deform_conv``. Default: 1.
        bias (None or bool): Must be falsy; this module has no bias.

    Raises:
        ValueError: If a truthy ``bias`` is requested.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 deformable_groups=1,
                 bias=None):
        # An explicit raise (instead of ``assert``) keeps the check under
        # ``python -O``; ``False`` is accepted as well as the default None.
        if bias:
            raise ValueError('DeformConv does not support bias')
        super(DeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.deformable_groups = deformable_groups

        # Weight layout: (out_channels, in_channels, kH, kW).
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight from U(-1/sqrt(n), 1/sqrt(n)), n = in_ch * kH * kW."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, offset):
        """Run ``deform_conv`` on ``input`` with the caller-provided offsets.

        NOTE(review): ``offset`` presumably carries
        2 * deformable_groups * kH * kW channels (as the packed modulated
        variant builds it) -- confirm against ``deform_conv``.
        """
        return deform_conv(input, offset, self.weight, self.stride,
                           self.padding, self.dilation, self.deformable_groups)


class ModulatedDeformConv(nn.Module):
    """Modulated deformable 2D convolution with an optional bias.

    Offsets and the modulation mask are supplied by the caller on every
    ``forward`` call; the computation is delegated to
    ``modulated_deform_conv``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple[int]): Kernel size; normalized to a pair.
        stride: Forwarded unchanged to ``modulated_deform_conv``. Default: 1.
        padding: Forwarded unchanged to ``modulated_deform_conv``.
            Default: 0.
        dilation: Forwarded unchanged to ``modulated_deform_conv``.
            Default: 1.
        deformable_groups (int): Forwarded to ``modulated_deform_conv``.
            Default: 1.
        bias (bool): If truthy, a zero-initialized bias of size
            ``out_channels`` is created; otherwise ``self.bias`` is None.
            Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        # NOTE(review): unlike DeformConv these are stored as given rather
        # than _pair-ed; modulated_deform_conv presumably expects scalars --
        # confirm before normalizing (the packed subclass _pairs them itself
        # when building its offset conv).
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.deformable_groups = deformable_groups
        self.with_bias = bias

        # Weight layout: (out_channels, in_channels, kH, kW).
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            # Registering None keeps ``self.bias`` defined for forward().
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform weight init in +/-1/sqrt(in_ch * kH * kW); zero the bias."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, input, offset, mask):
        """Run ``modulated_deform_conv`` with caller-provided offset and mask."""
        return modulated_deform_conv(input, offset, mask, self.weight,
                                     self.bias, self.stride, self.padding,
                                     self.dilation, self.deformable_groups)


class ModulatedDeformConvPack(ModulatedDeformConv):
    """ModulatedDeformConv that predicts its own offsets and mask.

    A regular ``nn.Conv2d`` (``conv_offset_mask``) is applied to the input.
    Its output is split into two offset chunks and one mask chunk. The mask
    is passed through a sigmoid, and the result drives the inherited
    modulated deformable convolution.

    Constructor arguments are identical to ``ModulatedDeformConv``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConvPack,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, deformable_groups, bias)

        # 3 chunks per deformable group and kernel position: two offset
        # chunks plus one mask chunk (split apart in forward()).
        # Stride, padding AND dilation must match the deformable conv so the
        # predicted offset/mask maps have the same spatial size as its output;
        # dilation was previously omitted, which broke dilation > 1.
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the offset/mask predictor.

        At initialization the offsets are therefore zero and the mask is
        sigmoid(0) = 0.5 everywhere.
        """
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input):
        """Predict offset and mask from ``input``, then deform-convolve it."""
        out = self.conv_offset_mask(input)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        # Delegate to the parent rather than duplicating its kernel call.
        return super(ModulatedDeformConvPack, self).forward(
            input, offset, mask)