deform_pool.py 6.89 KB
Newer Older
yhcao6's avatar
yhcao6 committed
1
2
from torch import nn

3
from ..functions.deform_pool import deform_roi_pooling
yhcao6's avatar
yhcao6 committed
4
5
6
7
8
9


class DeformRoIPooling(nn.Module):
    """Module wrapper around the ``deform_roi_pooling`` op.

    The constructor only records hyper-parameters; :meth:`forward` passes
    them, unchanged, to ``deform_roi_pooling`` together with the inputs.

    Args:
        spatial_scale: Forwarded to ``deform_roi_pooling`` (presumably the
            feature-map / image resolution ratio -- TODO confirm with op).
        out_size: Pooled output size, forwarded to the op.
        out_channels: Channel count, forwarded to the op.
        no_trans (bool): When true, the ``offset`` given to :meth:`forward`
            is ignored and an empty tensor is passed to the op instead.
        group_size: Forwarded to the op.
        part_size: Forwarded to the op; ``None`` means "use ``out_size``".
        sample_per_part: Forwarded to the op.
        trans_std: Forwarded to the op.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        # Fall back to the output size when no explicit part size is given.
        if part_size is None:
            part_size = out_size

        self.spatial_scale = spatial_scale
        self.out_size = out_size
        self.out_channels = out_channels
        self.no_trans = no_trans
        self.group_size = group_size
        self.part_size = part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, data, rois, offset):
        """Pool ``data`` over ``rois`` using the stored hyper-parameters.

        Args:
            data: Input feature tensor.
            rois: RoI tensor in whatever layout ``deform_roi_pooling``
                expects (TODO confirm, presumably (n, 5)).
            offset: Offset tensor; ignored when ``self.no_trans`` is true.

        Returns:
            Whatever ``deform_roi_pooling`` returns for these arguments.
        """
        # In no-translation mode the op receives an empty placeholder.
        effective_offset = data.new_empty(0) if self.no_trans else offset
        return deform_roi_pooling(
            data, rois, effective_offset,
            self.spatial_scale, self.out_size, self.out_channels,
            self.no_trans, self.group_size, self.part_size,
            self.sample_per_part, self.trans_std)
yhcao6's avatar
yhcao6 committed
34
35


36
class DeformRoIPoolingPack(DeformRoIPooling):
    """``DeformRoIPooling`` that predicts its own offsets with an FC head.

    When ``no_trans`` is false, each RoI is first pooled without offsets.
    The flattened result goes through a stack of ``num_offset_fcs`` Linear
    layers (ReLU between them) that outputs ``2 * out_size * out_size``
    values. These are reshaped to ``(n, 2, out_size, out_size)`` and used
    as offsets for a second pooling pass. The last Linear layer is
    zero-initialised, so the predicted offsets are exactly zero at the
    start of training.

    Args:
        spatial_scale, out_size, out_channels, no_trans, group_size,
        part_size, sample_per_part, trans_std: See ``DeformRoIPooling``.
        num_offset_fcs (int): Number of Linear layers in the offset head.
            Expected to be >= 1, since the last layer is indexed for
            zero-initialisation.
        deform_fc_channels (int): Width of the hidden Linear layers.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 deform_fc_channels=1024):
        super(DeformRoIPoolingPack,
              self).__init__(spatial_scale, out_size, out_channels, no_trans,
                             group_size, part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            seq = []
            # Input is one flattened, offset-free pooled RoI.
            ic = self.out_size * self.out_size * self.out_channels
            for i in range(self.num_offset_fcs):
                is_last = i == self.num_offset_fcs - 1
                # Hidden layers use deform_fc_channels; the last layer emits
                # one (dx, dy) pair per output bin.
                oc = (self.out_size * self.out_size * 2
                      if is_last else self.deform_fc_channels)
                seq.append(nn.Linear(ic, oc))
                ic = oc
                if not is_last:
                    seq.append(nn.ReLU(inplace=True))
            self.offset_fc = nn.Sequential(*seq)
            # Zero weights and bias => offsets start at exactly zero.
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois``, predicting offsets unless ``no_trans``.

        Args:
            data: Feature tensor whose dim 1 must equal ``out_channels``.
            rois: RoI tensor whose first dimension indexes RoIs (layout as
                expected by ``deform_roi_pooling``; presumably (n, 5) --
                TODO confirm).

        Returns:
            Pooled features. These are expected to be
            ``(n, out_channels, out_size, out_size)``, the size the offset
            head's input layer is built for. With zero RoIs, an empty tensor
            of that shape is returned directly.
        """
        assert data.size(1) == self.out_channels, (
            'expected {} input channels, got {}'.format(
                self.out_channels, data.size(1)))
        n = rois.shape[0]
        if n == 0:
            # With no RoIs, x.view(0, -1) below cannot infer its -1
            # dimension, so return an empty result of the usual shape and
            # skip the op entirely.
            return data.new_empty(n, self.out_channels, self.out_size,
                                  self.out_size)

        offset = data.new_empty(0)
        if self.no_trans:
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)

        # First pass: offset-free pooling (no_trans=True) feeds the FC head.
        x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                               self.out_size, self.out_channels, True,
                               self.group_size, self.part_size,
                               self.sample_per_part, self.trans_std)
        offset = self.offset_fc(x.view(n, -1))
        offset = offset.view(n, 2, self.out_size, self.out_size)
        # Second pass: pool again using the predicted offsets.
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.out_channels, self.no_trans, self.group_size,
            self.part_size, self.sample_per_part, self.trans_std)
93
94


95
class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling with predicted offsets and a per-bin mask.

    When ``no_trans`` is false, each RoI is first pooled without offsets.
    The flattened result feeds two FC heads:

    * an offset head of ``num_offset_fcs`` Linear layers (ReLU between
      them). It outputs ``(n, 2, out_size, out_size)`` offsets. Its last
      layer is zero-initialised, so offsets start at zero.
    * a mask head of ``num_mask_fcs`` Linear layers ending in a Sigmoid.
      It outputs ``(n, 1, out_size, out_size)`` values in (0, 1). Its last
      Linear is zero-initialised, so the mask starts at sigmoid(0) = 0.5.

    The second, offset-guided pooling result is multiplied by the mask,
    which broadcasts over channels.

    Args:
        spatial_scale, out_size, out_channels, no_trans, group_size,
        part_size, sample_per_part, trans_std: See ``DeformRoIPooling``.
        num_offset_fcs (int): Number of Linear layers in the offset head.
        num_mask_fcs (int): Number of Linear layers in the mask head.
        deform_fc_channels (int): Width of the hidden Linear layers.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 num_mask_fcs=2,
                 deform_fc_channels=1024):
        super(ModulatedDeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, out_channels, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.num_mask_fcs = num_mask_fcs
        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            # Offset head: flattened pooled RoI -> one (dx, dy) per bin.
            offset_fc_seq = []
            ic = self.out_size * self.out_size * self.out_channels
            for i in range(self.num_offset_fcs):
                is_last = i == self.num_offset_fcs - 1
                oc = (self.out_size * self.out_size * 2
                      if is_last else self.deform_fc_channels)
                offset_fc_seq.append(nn.Linear(ic, oc))
                ic = oc
                if not is_last:
                    offset_fc_seq.append(nn.ReLU(inplace=True))
            self.offset_fc = nn.Sequential(*offset_fc_seq)
            # Zero weights and bias => offsets start at exactly zero.
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()

            # Mask head: flattened pooled RoI -> one value per bin in (0, 1).
            mask_fc_seq = []
            ic = self.out_size * self.out_size * self.out_channels
            for i in range(self.num_mask_fcs):
                is_last = i == self.num_mask_fcs - 1
                oc = (self.out_size * self.out_size
                      if is_last else self.deform_fc_channels)
                mask_fc_seq.append(nn.Linear(ic, oc))
                ic = oc
                mask_fc_seq.append(
                    nn.Sigmoid() if is_last else nn.ReLU(inplace=True))
            self.mask_fc = nn.Sequential(*mask_fc_seq)
            # [-1] is the Sigmoid, so [-2] is the last Linear; zeroing it
            # makes the initial mask sigmoid(0) = 0.5 everywhere.
            self.mask_fc[-2].weight.data.zero_()
            self.mask_fc[-2].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with predicted offsets and mask.

        Args:
            data: Feature tensor whose dim 1 must equal ``out_channels``.
            rois: RoI tensor whose first dimension indexes RoIs (layout as
                expected by ``deform_roi_pooling``; presumably (n, 5) --
                TODO confirm).

        Returns:
            Pooled features, expected to be
            ``(n, out_channels, out_size, out_size)``. Unless ``no_trans``
            is set, they are scaled by the predicted mask. With zero RoIs,
            an empty tensor of that shape is returned directly.
        """
        assert data.size(1) == self.out_channels, (
            'expected {} input channels, got {}'.format(
                self.out_channels, data.size(1)))
        n = rois.shape[0]
        if n == 0:
            # With no RoIs, x.view(0, -1) below cannot infer its -1
            # dimension, so return an empty result of the usual shape and
            # skip the op entirely.
            return data.new_empty(n, self.out_channels, self.out_size,
                                  self.out_size)

        offset = data.new_empty(0)
        if self.no_trans:
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)

        # First pass: offset-free pooling (no_trans=True) feeds both heads.
        x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                               self.out_size, self.out_channels, True,
                               self.group_size, self.part_size,
                               self.sample_per_part, self.trans_std)
        flat = x.view(n, -1)
        offset = self.offset_fc(flat).view(n, 2, self.out_size,
                                           self.out_size)
        mask = self.mask_fc(flat).view(n, 1, self.out_size, self.out_size)
        # Second pass with predicted offsets, modulated per bin by the mask.
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.out_channels, self.no_trans, self.group_size,
            self.part_size, self.sample_per_part, self.trans_std) * mask