deform_pool.py 9.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from . import deform_pool_cuda


class DeformRoIPoolingFunction(Function):
    """Autograd function backed by the ``deform_pool_cuda`` extension.

    Both passes hand their work to the compiled CUDA kernels; tensors that
    are not on a CUDA device are rejected with ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx,
                data,
                rois,
                offset,
                spatial_scale,
                out_size,
                out_channels,
                no_trans,
                group_size=1,
                part_size=None,
                sample_per_part=4,
                trans_std=.0):
        """Run the forward pooling kernel.

        Args:
            data: Input feature tensor; must be a CUDA tensor.
            rois: Region tensor. Only ``rois.shape[0]`` (the number of
                regions) is read here; the rest is interpreted by the
                kernel.
            offset: Offset tensor, passed to the kernel unchanged.
            spatial_scale: Forwarded to the kernel.
            out_size: Side length of the square pooled output.
            out_channels: Channel count of the pooled output.
            no_trans: Forwarded to the kernel.
            group_size: Forwarded to the kernel.
            part_size: Forwarded to the kernel; ``None`` means ``out_size``.
            sample_per_part: Forwarded to the kernel.
            trans_std: Forwarded to the kernel; must lie in ``[0, 1]``.

        Returns:
            Tensor of shape
            ``(rois.shape[0], out_channels, out_size, out_size)`` allocated
            with ``data.new_empty`` and filled by the kernel.

        Raises:
            AssertionError: If ``trans_std`` is outside ``[0, 1]``.
            NotImplementedError: If ``data`` is not a CUDA tensor.
        """
        if part_size is None:
            part_size = out_size

        # Non-tensor settings are kept on ctx so backward() can replay them.
        ctx.spatial_scale = spatial_scale
        ctx.out_size = out_size
        ctx.out_channels = out_channels
        ctx.no_trans = no_trans
        ctx.group_size = group_size
        ctx.part_size = part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std

        assert 0.0 <= trans_std <= 1.0
        if not data.is_cuda:
            raise NotImplementedError

        pooled_shape = (rois.shape[0], out_channels, out_size, out_size)
        output = data.new_empty(pooled_shape)
        # Second kernel output, same shape as `output`; it is handed back to
        # the backward kernel.
        output_count = data.new_empty(pooled_shape)
        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
            data, rois, offset, output, output_count, no_trans,
            spatial_scale, out_channels, group_size, out_size, part_size,
            sample_per_part, trans_std)

        # Inputs are only retained when some gradient may be requested.
        if any(t.requires_grad for t in (data, rois, offset)):
            ctx.save_for_backward(data, rois, offset)
        ctx.output_count = output_count

        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Run the backward pooling kernel.

        Args:
            grad_output: Gradient w.r.t. the tensor returned by ``forward``;
                must be a CUDA tensor.

        Returns:
            An 11-tuple with one slot per ``forward`` argument after
            ``ctx``: gradients for ``data`` and ``offset``, ``None``
            everywhere else (including ``rois``).

        Raises:
            NotImplementedError: If ``grad_output`` is not a CUDA tensor.
        """
        if not grad_output.is_cuda:
            raise NotImplementedError

        data, rois, offset = ctx.saved_tensors
        # Zero-filled buffers that the kernel writes the gradients into.
        grad_data = torch.zeros_like(data)
        grad_offset = torch.zeros_like(offset)

        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
            grad_output, data, rois, offset, ctx.output_count, grad_data,
            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
            ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
            ctx.trans_std)

        # rois and the eight non-tensor settings receive no gradient.
        return (grad_data, None, grad_offset) + (None, ) * 8


# Functional entry point: invokes DeformRoIPoolingFunction through autograd's
# ``apply`` with the same positional arguments as ``forward`` (minus ``ctx``).
deform_roi_pooling = DeformRoIPoolingFunction.apply
yhcao6's avatar
yhcao6 committed
73
74
75
76
77
78


class DeformRoIPooling(nn.Module):
    """Module wrapper that stores pooling settings for ``deform_roi_pooling``.

    The constructor only records its arguments; ``forward`` passes them,
    together with the input tensors, to ``deform_roi_pooling``.

    Args:
        spatial_scale: Forwarded to the pooling function.
        out_size: Side length of the square pooled output.
        out_channels: Channel count of the pooled output.
        no_trans: When true, any supplied offset is ignored and an empty
            tensor is passed instead.
        group_size: Forwarded to the pooling function.
        part_size: Forwarded to the pooling function; ``None`` means
            ``out_size``.
        sample_per_part: Forwarded to the pooling function.
        trans_std: Forwarded to the pooling function.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.out_size = out_size
        self.out_channels = out_channels
        self.no_trans = no_trans
        self.group_size = group_size
        self.part_size = out_size if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, data, rois, offset=None):
        """Pool ``data`` over ``rois``.

        Args:
            data: Input feature tensor.
            rois: Region tensor.
            offset: Offset tensor. Ignored (and therefore optional) when
                ``self.no_trans`` is true; required otherwise.

        Returns:
            The tensor returned by ``deform_roi_pooling``.

        Raises:
            ValueError: If ``self.no_trans`` is false and ``offset`` is
                ``None``.
        """
        if self.no_trans:
            # The offset is unused in this mode; pass an empty placeholder.
            offset = data.new_empty(0)
        elif offset is None:
            raise ValueError('offset must be provided when no_trans is False')
        return deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                  self.out_size, self.out_channels,
                                  self.no_trans, self.group_size,
                                  self.part_size, self.sample_per_part,
                                  self.trans_std)
yhcao6's avatar
yhcao6 committed
104
105


106
class DeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling that predicts its own offsets.

    Unless ``no_trans`` is set, the module owns ``offset_fc``, a stack of
    ``num_offset_fcs`` linear layers (ReLU between them) mapping the
    flattened pooled features to ``2 * out_size * out_size`` values. The
    last layer is zero-initialised, so the predicted offsets start at zero.

    Args:
        spatial_scale, out_size, out_channels, no_trans, group_size,
        part_size, sample_per_part, trans_std: Passed to
            ``DeformRoIPooling.__init__``.
        num_offset_fcs: Number of linear layers in ``offset_fc``; must be
            at least 1 when ``no_trans`` is false.
        deform_fc_channels: Width of the hidden linear layers.

    Raises:
        ValueError: If ``no_trans`` is false and ``num_offset_fcs < 1``.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 deform_fc_channels=1024):
        super(DeformRoIPoolingPack,
              self).__init__(spatial_scale, out_size, out_channels, no_trans,
                             group_size, part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            if num_offset_fcs < 1:
                # Previously surfaced as an IndexError on offset_fc[-1].
                raise ValueError('num_offset_fcs must be >= 1 when no_trans '
                                 'is False, got {}'.format(num_offset_fcs))
            num_bins = self.out_size * self.out_size
            layers = []
            in_features = num_bins * self.out_channels
            # Hidden layers: Linear followed by ReLU.
            for _ in range(num_offset_fcs - 1):
                layers.append(nn.Linear(in_features, deform_fc_channels))
                layers.append(nn.ReLU(inplace=True))
                in_features = deform_fc_channels
            # Output layer: two values per output bin, no activation.
            layers.append(nn.Linear(in_features, num_bins * 2))
            self.offset_fc = nn.Sequential(*layers)
            nn.init.constant_(self.offset_fc[-1].weight, 0)
            nn.init.constant_(self.offset_fc[-1].bias, 0)

    def _pool(self, data, rois, offset, no_trans):
        """Call ``deform_roi_pooling`` with this module's stored settings."""
        return deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                  self.out_size, self.out_channels, no_trans,
                                  self.group_size, self.part_size,
                                  self.sample_per_part, self.trans_std)

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` using internally predicted offsets.

        Args:
            data: Input feature tensor; ``data.size(1)`` must equal
                ``self.out_channels``.
            rois: Region tensor; ``rois.shape[0]`` is the number of regions.

        Returns:
            The tensor returned by ``deform_roi_pooling``.
        """
        # NOTE(review): with group_size > 1 the kernel presumably expects
        # out_channels * group_size ** 2 input channels - confirm this check.
        assert data.size(1) == self.out_channels
        empty_offset = data.new_empty(0)
        if self.no_trans:
            return self._pool(data, rois, empty_offset, self.no_trans)

        n = rois.shape[0]
        # First pass without offsets yields the features the offsets are
        # predicted from.
        x = self._pool(data, rois, empty_offset, True)
        offset = self.offset_fc(x.view(n, -1))
        offset = offset.view(n, 2, self.out_size, self.out_size)
        # Second pass applies the predicted offsets.
        return self._pool(data, rois, offset, self.no_trans)
165
166


167
class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling with predicted offsets and a modulation mask.

    Unless ``no_trans`` is set, the module owns two linear stacks fed by the
    flattened pooled features:

    * ``offset_fc``: ``num_offset_fcs`` linear layers (ReLU between them)
      producing ``2 * out_size * out_size`` values; the last layer is
      zero-initialised.
    * ``mask_fc``: ``num_mask_fcs`` linear layers (ReLU between them)
      followed by a sigmoid, producing ``out_size * out_size`` values; the
      last linear layer is zero-initialised, so the mask starts at 0.5.

    Args:
        spatial_scale, out_size, out_channels, no_trans, group_size,
        part_size, sample_per_part, trans_std: Passed to
            ``DeformRoIPooling.__init__``.
        num_offset_fcs: Number of linear layers in ``offset_fc``.
        num_mask_fcs: Number of linear layers in ``mask_fc``.
        deform_fc_channels: Width of the hidden linear layers.

    Raises:
        ValueError: If ``no_trans`` is false and ``num_offset_fcs`` or
            ``num_mask_fcs`` is below 1.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 num_mask_fcs=2,
                 deform_fc_channels=1024):
        super(ModulatedDeformRoIPoolingPack,
              self).__init__(spatial_scale, out_size, out_channels, no_trans,
                             group_size, part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.num_mask_fcs = num_mask_fcs
        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            # Previously these surfaced as IndexErrors on the empty stacks.
            if num_offset_fcs < 1:
                raise ValueError('num_offset_fcs must be >= 1 when no_trans '
                                 'is False, got {}'.format(num_offset_fcs))
            if num_mask_fcs < 1:
                raise ValueError('num_mask_fcs must be >= 1 when no_trans '
                                 'is False, got {}'.format(num_mask_fcs))

            num_bins = self.out_size * self.out_size
            in_features = num_bins * self.out_channels

            # Offset branch is built first, then the mask branch.
            self.offset_fc = nn.Sequential(*self._fc_stack(
                in_features, deform_fc_channels, num_bins * 2,
                num_offset_fcs))
            nn.init.constant_(self.offset_fc[-1].weight, 0)
            nn.init.constant_(self.offset_fc[-1].bias, 0)

            mask_layers = self._fc_stack(in_features, deform_fc_channels,
                                         num_bins, num_mask_fcs)
            mask_layers.append(nn.Sigmoid())
            self.mask_fc = nn.Sequential(*mask_layers)
            # Index -2 is the last Linear; -1 is the Sigmoid.
            nn.init.constant_(self.mask_fc[-2].weight, 0)
            nn.init.constant_(self.mask_fc[-2].bias, 0)

    @staticmethod
    def _fc_stack(in_features, hidden_features, out_features, num_fcs):
        """Return ``num_fcs`` Linear layers with a ReLU after each hidden one."""
        layers = []
        for _ in range(num_fcs - 1):
            layers.append(nn.Linear(in_features, hidden_features))
            layers.append(nn.ReLU(inplace=True))
            in_features = hidden_features
        layers.append(nn.Linear(in_features, out_features))
        return layers

    def _pool(self, data, rois, offset, no_trans):
        """Call ``deform_roi_pooling`` with this module's stored settings."""
        return deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                  self.out_size, self.out_channels, no_trans,
                                  self.group_size, self.part_size,
                                  self.sample_per_part, self.trans_std)

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with predicted offsets and mask.

        Args:
            data: Input feature tensor; ``data.size(1)`` must equal
                ``self.out_channels``.
            rois: Region tensor; ``rois.shape[0]`` is the number of regions.

        Returns:
            With ``no_trans`` set, the plain pooled tensor; otherwise the
            offset-pooled tensor multiplied by the predicted mask of shape
            ``(n, 1, out_size, out_size)``.
        """
        # NOTE(review): with group_size > 1 the kernel presumably expects
        # out_channels * group_size ** 2 input channels - confirm this check.
        assert data.size(1) == self.out_channels
        empty_offset = data.new_empty(0)
        if self.no_trans:
            return self._pool(data, rois, empty_offset, self.no_trans)

        n = rois.shape[0]
        # First pass without offsets yields the features both branches read.
        x = self._pool(data, rois, empty_offset, True)
        flat = x.view(n, -1)
        offset = self.offset_fc(flat)
        offset = offset.view(n, 2, self.out_size, self.out_size)
        mask = self.mask_fc(flat)
        mask = mask.view(n, 1, self.out_size, self.out_size)
        # Second pass applies the offsets; the mask scales every bin.
        return self._pool(data, rois, offset, self.no_trans) * mask