deform_pool.py 3.87 KB
Newer Older
yhcao6's avatar
yhcao6 committed
1
2
from torch import nn

3
from ..functions.deform_pool import deform_roi_pooling
yhcao6's avatar
yhcao6 committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30


class DeformRoIPooling(nn.Module):
    """``nn.Module`` wrapper around the ``deform_roi_pooling`` function.

    The constructor only records its configuration; ``forward`` passes the
    stored values, unchanged and in a fixed order, to
    ``deform_roi_pooling``.

    Args:
        spatial_scale: Forwarded to ``deform_roi_pooling``. Presumably the
            factor mapping RoI coordinates onto the feature map -- confirm
            against the function implementation.
        pooled_size: Forwarded to ``deform_roi_pooling``; also exposed as
            ``out_size``.
        output_dim: Forwarded to ``deform_roi_pooling``.
        no_trans: When truthy, the ``offset`` given to ``forward`` is
            ignored and an empty tensor is used instead.
        group_size: Forwarded to ``deform_roi_pooling``.
        part_size: Forwarded to ``deform_roi_pooling``. Falls back to
            ``pooled_size`` when ``None``.
        sample_per_part: Forwarded to ``deform_roi_pooling``.
        trans_std: Forwarded to ``deform_roi_pooling``.
    """

    def __init__(self,
                 spatial_scale,
                 pooled_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.pooled_size = pooled_size
        # Alias of ``pooled_size`` kept as a separate public attribute.
        self.out_size = pooled_size
        self.output_dim = output_dim
        self.no_trans = no_trans
        self.group_size = group_size
        # ``None`` means "use the pooled size as the part size".
        self.part_size = pooled_size if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, data, rois, offset):
        """Apply ``deform_roi_pooling`` to ``data`` for the given ``rois``.

        Args:
            data: Input tensor handed to ``deform_roi_pooling``.
            rois: Regions of interest handed to ``deform_roi_pooling``.
            offset: Offset tensor; replaced by an empty tensor created from
                ``data`` when ``self.no_trans`` is set.

        Returns:
            Whatever ``deform_roi_pooling`` returns.
        """
        if self.no_trans:
            # Without a transformation the offset is unused, so substitute
            # an empty placeholder tensor derived from ``data``.
            offset = data.new()
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.pooled_size,
            self.output_dim, self.no_trans, self.group_size, self.part_size,
            self.sample_per_part, self.trans_std)
yhcao6's avatar
yhcao6 committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81


class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling that predicts its own offsets and a mask.

    Unless ``no_trans`` is set, two small fully connected heads are built:
    ``offset_fc`` produces ``2 * pooled_size * pooled_size`` values per RoI
    and ``mask_fc`` produces ``pooled_size * pooled_size`` values per RoI
    passed through a sigmoid. Both take the flattened output of a first
    pooling pass as input.

    Args:
        spatial_scale: Passed to the parent constructor.
        pooled_size: Passed to the parent constructor; also sizes the
            inputs and outputs of the two heads.
        output_dim: Passed to the parent constructor; also sizes the inputs
            of the two heads.
        no_trans: Passed to the parent constructor. When truthy, no heads
            are created and ``forward`` performs a single pooling pass.
        group_size: Passed to the parent constructor.
        part_size: Passed to the parent constructor.
        sample_per_part: Passed to the parent constructor.
        trans_std: Passed to the parent constructor.
        deform_fc_dim: Hidden width of the fully connected heads.
    """

    def __init__(self,
                 spatial_scale,
                 pooled_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 deform_fc_dim=1024):
        super(ModulatedDeformRoIPoolingPack, self).__init__(
            spatial_scale, pooled_size, output_dim, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.deform_fc_dim = deform_fc_dim

        if not no_trans:
            num_bins = self.pooled_size * self.pooled_size
            in_features = num_bins * self.output_dim

            self.offset_fc = nn.Sequential(
                nn.Linear(in_features, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, num_bins * 2))
            # Zero the last layer so the predicted offsets start at zero.
            self.offset_fc[4].weight.data.zero_()
            self.offset_fc[4].bias.data.zero_()

            self.mask_fc = nn.Sequential(
                nn.Linear(in_features, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, num_bins * 1),
                nn.Sigmoid())
            # Zero the last linear layer so the sigmoid initially outputs a
            # constant 0.5 for every bin.
            self.mask_fc[2].weight.data.zero_()
            self.mask_fc[2].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois``, learning offsets and a mask if enabled.

        Args:
            data: Input tensor handed to ``deform_roi_pooling``.
            rois: Regions of interest; ``rois.shape[0]`` is used as the
                number of RoIs when reshaping the head outputs.

        Returns:
            With ``no_trans`` set, the result of a single
            ``deform_roi_pooling`` call with an empty offset. Otherwise the
            result of a second pooling pass that uses the predicted
            offsets, multiplied by the predicted mask.
        """
        if self.no_trans:
            return deform_roi_pooling(
                data, rois, data.new(), self.spatial_scale, self.pooled_size,
                self.output_dim, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)

        n = rois.shape[0]

        # First pass: pool with transformation disabled (``no_trans`` is
        # hard-coded to True) to obtain the features the heads consume.
        x = deform_roi_pooling(data, rois, data.new(), self.spatial_scale,
                               self.pooled_size, self.output_dim, True,
                               self.group_size, self.part_size,
                               self.sample_per_part, self.trans_std)

        offset = self.offset_fc(x.view(n, -1))
        offset = offset.view(n, 2, self.pooled_size, self.pooled_size)
        mask = self.mask_fc(x.view(n, -1))
        mask = mask.view(n, 1, self.pooled_size, self.pooled_size)

        # Second pass: pool again with the predicted offsets, then modulate
        # the result with the predicted mask.
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.pooled_size,
            self.output_dim, self.no_trans, self.group_size,
            self.part_size, self.sample_per_part, self.trans_std) * mask