from torch import nn

from ..functions.deform_pool import deform_roi_pooling


class DeformRoIPooling(nn.Module):
    """Module wrapper around ``deform_roi_pooling`` with caller-given offsets.

    The constructor only records the pooling configuration; ``forward``
    passes that configuration, together with its inputs, to
    ``deform_roi_pooling``.

    Args:
        spatial_scale: Stored and forwarded unchanged to the pooling op.
            (Presumably the feature-map / image coordinate ratio -- confirm
            against the op implementation.)
        out_size: Stored and forwarded unchanged; also the fallback value
            for ``part_size``.
        output_dim: Stored and forwarded unchanged.
        no_trans: When truthy, ``forward`` ignores the ``offset`` argument
            and passes an empty placeholder tensor instead.
        group_size: Stored and forwarded unchanged. Defaults to 1.
        part_size: Stored and forwarded; when ``None`` it falls back to
            ``out_size``.
        sample_per_part: Stored and forwarded unchanged. Defaults to 4.
        trans_std: Stored and forwarded unchanged. Defaults to 0.0.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.out_size = out_size
        self.output_dim = output_dim
        self.no_trans = no_trans
        self.group_size = group_size
        # An unspecified part size defaults to the pooled output size.
        self.part_size = out_size if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, data, rois, offset):
        """Run ``deform_roi_pooling`` on ``data`` for the given ``rois``.

        Args:
            data: Input tensor handed to the pooling op.
            rois: Regions of interest handed to the pooling op.
            offset: Offsets handed to the pooling op; replaced by an empty
                tensor when ``self.no_trans`` is truthy.

        Returns:
            Whatever ``deform_roi_pooling`` returns.
        """
        if self.no_trans:
            # The supplied offset is discarded; an empty tensor created from
            # ``data`` (so it shares dtype and device) is used as placeholder.
            offset = data.new_empty(0)
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.output_dim, self.no_trans, self.group_size, self.part_size,
            self.sample_per_part, self.trans_std)


class DeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling that predicts its own offsets.

    Unless ``no_trans`` is set, the module owns a three-layer fully
    connected head (``offset_fc``) that maps offset-free pooled features to
    ``2 * out_size * out_size`` offset values, which are then used for a
    second pooling pass.

    Args:
        spatial_scale, out_size, output_dim, no_trans, group_size,
        part_size, sample_per_part, trans_std: Passed unchanged to
            ``DeformRoIPooling.__init__``.
        deform_fc_dim: Width of the hidden layers of ``offset_fc``.
            Defaults to 1024.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 deform_fc_dim=1024):
        super(DeformRoIPoolingPack,
              self).__init__(spatial_scale, out_size, output_dim, no_trans,
                             group_size, part_size, sample_per_part, trans_std)

        self.deform_fc_dim = deform_fc_dim

        # The offset head only exists when offsets are actually used.
        if not no_trans:
            self.offset_fc = nn.Sequential(
                nn.Linear(self.out_size * self.out_size * self.output_dim,
                          self.deform_fc_dim), nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim,
                          self.out_size * self.out_size * 2))
            # Index 4 is the final Linear layer; zeroing it makes the head
            # output all-zero offsets right after construction.
            self.offset_fc[4].weight.data.zero_()
            self.offset_fc[4].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois``, predicting offsets when enabled.

        Args:
            data: Input tensor handed to the pooling op.
            rois: Regions of interest; ``rois.shape[0]`` is used as the
                number of regions.

        Returns:
            The result of ``deform_roi_pooling``: a single offset-free pass
            when ``self.no_trans`` is truthy, otherwise the second pass
            driven by the predicted offsets.
        """
        # Empty placeholder offset, created from ``data`` so it shares dtype
        # and device; used by every offset-free pooling call below.
        empty_offset = data.new_empty(0)

        if self.no_trans:
            return deform_roi_pooling(
                data, rois, empty_offset, self.spatial_scale, self.out_size,
                self.output_dim, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)

        n = rois.shape[0]
        # First pass: pool with no_trans forced to True to obtain the
        # features the offset head is fed with.
        x = deform_roi_pooling(data, rois, empty_offset, self.spatial_scale,
                               self.out_size, self.output_dim, True,
                               self.group_size, self.part_size,
                               self.sample_per_part, self.trans_std)
        offset = self.offset_fc(x.view(n, -1))
        offset = offset.view(n, 2, self.out_size, self.out_size)
        # Second pass: pool again using the predicted offsets.
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.output_dim, self.no_trans, self.group_size,
            self.part_size, self.sample_per_part, self.trans_std)


class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling with predicted offsets and a modulation mask.

    Unless ``no_trans`` is set, the module owns two fully connected heads
    fed with offset-free pooled features: ``offset_fc`` predicts
    ``2 * out_size * out_size`` offset values and ``mask_fc`` predicts
    ``out_size * out_size`` sigmoid-activated mask values that scale the
    output of the second pooling pass.

    Args:
        spatial_scale, out_size, output_dim, no_trans, group_size,
        part_size, sample_per_part, trans_std: Passed unchanged to
            ``DeformRoIPooling.__init__``.
        deform_fc_dim: Width of the hidden layers of both heads.
            Defaults to 1024.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 deform_fc_dim=1024):
        super(ModulatedDeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, output_dim, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.deform_fc_dim = deform_fc_dim

        # Both heads only exist when offsets are actually used.
        if not no_trans:
            self.offset_fc = nn.Sequential(
                nn.Linear(self.out_size * self.out_size * self.output_dim,
                          self.deform_fc_dim), nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim,
                          self.out_size * self.out_size * 2))
            # Index 4 is the final Linear layer; zeroing it makes the head
            # output all-zero offsets right after construction.
            self.offset_fc[4].weight.data.zero_()
            self.offset_fc[4].bias.data.zero_()
            self.mask_fc = nn.Sequential(
                nn.Linear(self.out_size * self.out_size * self.output_dim,
                          self.deform_fc_dim), nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim,
                          self.out_size * self.out_size * 1), nn.Sigmoid())
            # Index 2 is the Linear layer feeding the Sigmoid; with zero
            # weights and bias the mask is a constant 0.5 after construction.
            self.mask_fc[2].weight.data.zero_()
            self.mask_fc[2].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with predicted offsets and mask.

        Args:
            data: Input tensor handed to the pooling op.
            rois: Regions of interest; ``rois.shape[0]`` is used as the
                number of regions.

        Returns:
            The result of ``deform_roi_pooling``: a single offset-free pass
            when ``self.no_trans`` is truthy, otherwise the second pass
            driven by the predicted offsets and multiplied by the mask.
        """
        # Empty placeholder offset, created from ``data`` so it shares dtype
        # and device; used by every offset-free pooling call below.
        empty_offset = data.new_empty(0)

        if self.no_trans:
            return deform_roi_pooling(
                data, rois, empty_offset, self.spatial_scale, self.out_size,
                self.output_dim, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)

        n = rois.shape[0]
        # First pass: pool with no_trans forced to True to obtain the
        # features both heads are fed with.
        x = deform_roi_pooling(data, rois, empty_offset, self.spatial_scale,
                               self.out_size, self.output_dim, True,
                               self.group_size, self.part_size,
                               self.sample_per_part, self.trans_std)
        flat = x.view(n, -1)
        offset = self.offset_fc(flat)
        offset = offset.view(n, 2, self.out_size, self.out_size)
        mask = self.mask_fc(flat)
        # One mask value per output bin; the singleton second dimension
        # broadcasts against the pooled features in the product below.
        mask = mask.view(n, 1, self.out_size, self.out_size)
        # Second pass: pool with the predicted offsets, then modulate.
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.output_dim, self.no_trans, self.group_size,
            self.part_size, self.sample_per_part, self.trans_std) * mask