deform_pool.py 9.97 KB
Newer Older
1
2
3
4
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
Cao Yuhang's avatar
Cao Yuhang committed
5
from torch.nn.modules.utils import _pair
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

from . import deform_pool_cuda


class DeformRoIPoolingFunction(Function):
    """Autograd function for deformable position-sensitive RoI pooling.

    Both passes dispatch to the ``deform_pool_cuda`` extension; non-CUDA
    inputs raise ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx,
                data,
                rois,
                offset,
                spatial_scale,
                out_size,
                out_channels,
                no_trans,
                group_size=1,
                part_size=None,
                sample_per_part=4,
                trans_std=.0):
        """Run the CUDA forward kernel and return the pooled features.

        Args:
            ctx: autograd context; hyper-parameters and the kernel's
                ``output_count`` buffer are stored on it for ``backward``.
            data (Tensor): input feature map; must live on a CUDA device.
            rois (Tensor): regions of interest; ``rois.shape[0]`` gives the
                number of pooled outputs (row layout is defined by the
                CUDA kernel, which is not visible here).
            offset (Tensor): offsets forwarded to the kernel.
            spatial_scale (float): forwarded to the kernel.
            out_size (int | tuple[int, int]): pooled size; must be square
                and made of ints.
            out_channels (int): number of output channels.
            no_trans (bool): forwarded to the kernel.
            group_size (int): forwarded to the kernel.
            part_size (int, optional): defaults to the output side length.
            sample_per_part (int): forwarded to the kernel.
            trans_std (float): must lie in ``[0, 1]``.

        Returns:
            Tensor: shape ``(rois.shape[0], out_channels, k, k)`` where
            ``k`` is the output side length.

        Raises:
            AssertionError: ``out_size`` is non-square / non-int, or
                ``trans_std`` is outside ``[0, 1]``.
            NotImplementedError: ``data`` is not a CUDA tensor.
        """
        # TODO: support unsquare RoIs
        pooled_h, pooled_w = _pair(out_size)
        assert isinstance(pooled_h, int) and isinstance(pooled_w, int)
        assert pooled_h == pooled_w
        side = pooled_h  # only square outputs are supported for now

        # Everything the backward kernel needs is replayed from ctx.
        ctx.spatial_scale = spatial_scale
        ctx.out_size = side
        ctx.out_channels = out_channels
        ctx.no_trans = no_trans
        ctx.group_size = group_size
        ctx.part_size = side if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std

        assert 0.0 <= ctx.trans_std <= 1.0
        if not data.is_cuda:
            raise NotImplementedError

        num_rois = rois.shape[0]
        out_shape = (num_rois, out_channels, side, side)
        output = data.new_empty(*out_shape)
        output_count = data.new_empty(*out_shape)
        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
            data, rois, offset, output, output_count, ctx.no_trans,
            ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
            ctx.part_size, ctx.sample_per_part, ctx.trans_std)

        # Inputs are only kept alive when a gradient can actually flow.
        if any(t.requires_grad for t in (data, rois, offset)):
            ctx.save_for_backward(data, rois, offset)
        ctx.output_count = output_count
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. ``data`` and ``offset`` on CUDA.

        ``rois`` and every non-tensor hyper-parameter receive ``None``.

        Raises:
            NotImplementedError: ``grad_output`` is not a CUDA tensor.
        """
        if not grad_output.is_cuda:
            raise NotImplementedError

        data, rois, offset = ctx.saved_tensors
        grad_data = torch.zeros_like(data)
        grad_offset = torch.zeros_like(offset)

        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
            grad_output, data, rois, offset, ctx.output_count, grad_data,
            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
            ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
            ctx.trans_std)
        # One slot per forward() input: data, rois, offset, then the eight
        # hyper-parameters (spatial_scale ... trans_std).
        return (grad_data, None, grad_offset) + (None, ) * 8


# Functional entry point with the same positional signature as
# DeformRoIPoolingFunction.forward (minus ctx).
deform_roi_pooling = DeformRoIPoolingFunction.apply


class DeformRoIPooling(nn.Module):
    """Module wrapper around :func:`deform_roi_pooling`.

    Args:
        spatial_scale (float): forwarded to the pooling op.
        out_size (int | tuple[int, int]): pooled output size; stored as a
            pair in ``self.out_size`` (the pooling op requires it square).
        out_channels (int): number of output channels.
        no_trans (bool): if True, the ``offset`` passed to ``forward`` is
            replaced by an empty tensor.
        group_size (int): forwarded to the pooling op.
        part_size (int, optional): forwarded to the pooling op; defaults to
            the output side length.
        sample_per_part (int): forwarded to the pooling op.
        trans_std (float): forwarded to the pooling op, which asserts it
            lies in ``[0, 1]``.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.out_size = _pair(out_size)
        self.out_channels = out_channels
        self.no_trans = no_trans
        self.group_size = group_size
        # Default to the scalar side length, not the raw ``out_size``
        # argument: a tuple ``out_size`` would otherwise leak a tuple into
        # ``part_size``, while the pooling op treats every size as an int.
        self.part_size = self.out_size[0] if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, data, rois, offset):
        """Pool ``data`` over ``rois``.

        ``offset`` is ignored (replaced by an empty tensor) when
        ``self.no_trans`` is set.
        """
        if self.no_trans:
            offset = data.new_empty(0)
        return deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                  self.out_size, self.out_channels,
                                  self.no_trans, self.group_size,
                                  self.part_size, self.sample_per_part,
                                  self.trans_std)


class DeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling that predicts its own offsets.

    Unless ``no_trans`` is set, a first pooling pass with ``no_trans=True``
    is flattened and fed to ``offset_fc`` (a stack of Linear/ReLU layers).
    Its output is reshaped to ``(n, 2, out_h, out_w)`` and used as the
    offset of a second pooling pass.

    Extra args:
        num_offset_fcs (int): number of Linear layers in ``offset_fc``.
        deform_fc_channels (int): width of the hidden Linear layers.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 deform_fc_channels=1024):
        super(DeformRoIPoolingPack,
              self).__init__(spatial_scale, out_size, out_channels, no_trans,
                             group_size, part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.deform_fc_channels = deform_fc_channels

        if no_trans:
            # No offsets are predicted, so no offset head is built.
            return

        num_bins = self.out_size[0] * self.out_size[1]
        layers = []
        in_features = num_bins * self.out_channels
        for idx in range(self.num_offset_fcs):
            is_last = idx == self.num_offset_fcs - 1
            # The final layer emits an (x, y) offset pair per output bin.
            out_features = num_bins * 2 if is_last else self.deform_fc_channels
            layers.append(nn.Linear(in_features, out_features))
            if not is_last:
                layers.append(nn.ReLU(inplace=True))
            in_features = out_features
        self.offset_fc = nn.Sequential(*layers)
        # Zero the final projection so predicted offsets start at zero.
        self.offset_fc[-1].weight.data.zero_()
        self.offset_fc[-1].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with learned offsets (if enabled)."""
        assert data.size(1) == self.out_channels
        empty_offset = data.new_empty(0)
        if self.no_trans:
            return deform_roi_pooling(data, rois, empty_offset,
                                      self.spatial_scale, self.out_size,
                                      self.out_channels, self.no_trans,
                                      self.group_size, self.part_size,
                                      self.sample_per_part, self.trans_std)

        num_rois = rois.shape[0]
        # Pass 1: plain pooling (no_trans forced True) to predict offsets.
        pooled = deform_roi_pooling(data, rois, empty_offset,
                                    self.spatial_scale, self.out_size,
                                    self.out_channels, True, self.group_size,
                                    self.part_size, self.sample_per_part,
                                    self.trans_std)
        predicted = self.offset_fc(pooled.view(num_rois, -1))
        offset = predicted.view(num_rois, 2, self.out_size[0],
                                self.out_size[1])
        # Pass 2: deformable pooling driven by the predicted offsets.
        return deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                  self.out_size, self.out_channels,
                                  self.no_trans, self.group_size,
                                  self.part_size, self.sample_per_part,
                                  self.trans_std)


class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling with learned offsets and a learned mask.

    Unless ``no_trans`` is set, a first pooling pass with ``no_trans=True``
    is flattened and fed to two heads:

    - ``offset_fc``: Linear/ReLU stack whose output is reshaped to
      ``(n, 2, out_h, out_w)`` and used as the offset of a second pass.
    - ``mask_fc``: Linear/ReLU stack ending in a Sigmoid, reshaped to
      ``(n, 1, out_h, out_w)``. It multiplies the second pass's output.

    Extra args:
        num_offset_fcs (int): number of Linear layers in ``offset_fc``.
        num_mask_fcs (int): number of Linear layers in ``mask_fc``.
        deform_fc_channels (int): width of the hidden Linear layers.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 num_mask_fcs=2,
                 deform_fc_channels=1024):
        super(ModulatedDeformRoIPoolingPack,
              self).__init__(spatial_scale, out_size, out_channels, no_trans,
                             group_size, part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.num_mask_fcs = num_mask_fcs
        self.deform_fc_channels = deform_fc_channels

        if no_trans:
            # Neither offsets nor masks are predicted; no heads are built.
            return

        num_bins = self.out_size[0] * self.out_size[1]
        feat_dim = num_bins * self.out_channels

        # Offset head: hidden Linear+ReLU layers, then one (x, y) pair per bin.
        offset_layers = []
        in_features = feat_dim
        for idx in range(self.num_offset_fcs):
            is_last = idx == self.num_offset_fcs - 1
            out_features = num_bins * 2 if is_last else self.deform_fc_channels
            offset_layers.append(nn.Linear(in_features, out_features))
            if not is_last:
                offset_layers.append(nn.ReLU(inplace=True))
            in_features = out_features
        self.offset_fc = nn.Sequential(*offset_layers)
        self.offset_fc[-1].weight.data.zero_()
        self.offset_fc[-1].bias.data.zero_()

        # Mask head: hidden Linear+ReLU layers, then one Sigmoid gate per bin.
        mask_layers = []
        in_features = feat_dim
        for idx in range(self.num_mask_fcs):
            is_last = idx == self.num_mask_fcs - 1
            out_features = num_bins if is_last else self.deform_fc_channels
            mask_layers.append(nn.Linear(in_features, out_features))
            activation = nn.Sigmoid() if is_last else nn.ReLU(inplace=True)
            mask_layers.append(activation)
            in_features = out_features
        self.mask_fc = nn.Sequential(*mask_layers)
        # mask_fc[-1] is the Sigmoid, so the last Linear sits at index -2;
        # zeroing it makes the initial mask sigmoid(0) = 0.5 everywhere.
        self.mask_fc[-2].weight.data.zero_()
        self.mask_fc[-2].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with learned offsets and mask."""
        assert data.size(1) == self.out_channels
        empty_offset = data.new_empty(0)
        if self.no_trans:
            return deform_roi_pooling(data, rois, empty_offset,
                                      self.spatial_scale, self.out_size,
                                      self.out_channels, self.no_trans,
                                      self.group_size, self.part_size,
                                      self.sample_per_part, self.trans_std)

        num_rois = rois.shape[0]
        # Pass 1: plain pooling (no_trans forced True) feeding both heads.
        pooled = deform_roi_pooling(data, rois, empty_offset,
                                    self.spatial_scale, self.out_size,
                                    self.out_channels, True, self.group_size,
                                    self.part_size, self.sample_per_part,
                                    self.trans_std)
        flat = pooled.view(num_rois, -1)
        out_h, out_w = self.out_size[0], self.out_size[1]
        offset = self.offset_fc(flat).view(num_rois, 2, out_h, out_w)
        mask = self.mask_fc(flat).view(num_rois, 1, out_h, out_w)
        # Pass 2: deformable pooling, then per-bin modulation by the mask.
        modulated = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                       self.out_size, self.out_channels,
                                       self.no_trans, self.group_size,
                                       self.part_size, self.sample_per_part,
                                       self.trans_std)
        return modulated * mask