"docs/zh_cn/supported_tasks/lidar_det3d.md" did not exist on "5f1366cef0b8d82269f762ada3d23a67205077b5"
deform_pool.py 5.95 KB
Newer Older
yhcao6's avatar
yhcao6 committed
1
2
from torch import nn

3
from ..functions.deform_pool import deform_roi_pooling
yhcao6's avatar
yhcao6 committed
4
5
6
7
8
9


class DeformRoIPooling(nn.Module):
    """Deformable RoI pooling driven by caller-supplied offsets.

    The module only holds pooling hyper-parameters; all computation is done
    by ``deform_roi_pooling``, which receives them unchanged.

    Args:
        spatial_scale: Scale forwarded to ``deform_roi_pooling`` (presumably
            feature-map resolution relative to the RoI coordinates — confirm).
        out_size: Output size forwarded to ``deform_roi_pooling`` (presumably
            the side length of a square output grid).
        output_dim: Number of output channels forwarded to the op.
        no_trans: When true, ``forward`` ignores its ``offset`` argument and
            passes an empty tensor instead.
        group_size: Forwarded unchanged to ``deform_roi_pooling``.
        part_size: Forwarded unchanged; ``None`` means "use ``out_size``".
        sample_per_part: Forwarded unchanged to ``deform_roi_pooling``.
        trans_std: Forwarded unchanged to ``deform_roi_pooling``.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.out_size = out_size
        self.output_dim = output_dim
        self.no_trans = no_trans
        self.group_size = group_size
        # A missing part size falls back to the output size.
        if part_size is None:
            part_size = out_size
        self.part_size = part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, data, rois, offset):
        """Pool ``data`` over ``rois`` with the given ``offset``.

        When ``no_trans`` is set, ``offset`` is discarded and replaced by an
        empty tensor created from ``data``.
        """
        # Hyper-parameters in the positional order deform_roi_pooling expects
        # after (data, rois, offset).
        pooling_args = (self.spatial_scale, self.out_size, self.output_dim,
                        self.no_trans, self.group_size, self.part_size,
                        self.sample_per_part, self.trans_std)
        if self.no_trans:
            offset = data.new_empty(0)
        return deform_roi_pooling(data, rois, offset, *pooling_args)
yhcao6's avatar
yhcao6 committed
34
35
36
37
38
39


class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
    """Modulated deformable RoI pooling that predicts its own offsets and mask.

    ``forward`` takes no ``offset`` argument. When ``no_trans`` is false:

    1. ``deform_roi_pooling`` is run without offsets (``no_trans=True``) to
       obtain per-RoI features.
    2. ``offset_fc`` maps the flattened features to 2 offset values per
       output cell, and ``mask_fc`` maps them to 1 sigmoid value per cell.
    3. ``deform_roi_pooling`` is run again with the predicted offsets, and the
       result is multiplied by the mask.

    The final linear layers of both heads are zero-initialised, so at
    initialisation the predicted offsets are 0 and the mask is sigmoid(0),
    i.e. 0.5.

    Args:
        spatial_scale, out_size, output_dim, no_trans, group_size,
        part_size, sample_per_part, trans_std: Forwarded to the
            ``DeformRoIPooling`` constructor.
        deform_fc_dim: Hidden width of the offset and mask heads.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 deform_fc_dim=1024):
        super(ModulatedDeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, output_dim, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.deform_fc_dim = deform_fc_dim

        if not no_trans:
            # Both heads consume the flattened first-pass features:
            # out_size * out_size * output_dim values per RoI.
            in_features = self.out_size * self.out_size * self.output_dim
            self.offset_fc = nn.Sequential(
                nn.Linear(in_features, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim,
                          self.out_size * self.out_size * 2))
            # Zero the last Linear so predicted offsets start at 0.
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()
            self.mask_fc = nn.Sequential(
                nn.Linear(in_features, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim,
                          self.out_size * self.out_size * 1),
                nn.Sigmoid())
            # [-2] is the last Linear (before Sigmoid); zeroing it makes the
            # initial mask sigmoid(0) = 0.5 everywhere.
            self.mask_fc[-2].weight.data.zero_()
            self.mask_fc[-2].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with predicted offsets and mask.

        Args:
            data: Feature map passed to ``deform_roi_pooling``.
            rois: RoI tensor; its first dimension is the number of RoIs.

        Returns:
            The pooled features. In the learned-offset path with zero RoIs,
            an empty tensor of shape ``(0, output_dim, out_size, out_size)``
            is returned without running the pooling op.
        """
        offset = data.new_empty(0)
        if self.no_trans:
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.output_dim, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)

        n = rois.shape[0]
        if n == 0:
            # x.view(0, -1) below cannot infer -1 from a 0-element tensor,
            # so short-circuit with an empty result in the layout the mask
            # broadcast (n, 1, out_size, out_size) implies.
            return data.new_empty(n, self.output_dim, self.out_size,
                                  self.out_size)

        # First pass: pool without offsets (no_trans=True) to obtain the
        # features the offset and mask heads consume.
        x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                               self.out_size, self.output_dim, True,
                               self.group_size, self.part_size,
                               self.sample_per_part, self.trans_std)
        x_flat = x.view(n, -1)
        offset = self.offset_fc(x_flat).view(n, 2, self.out_size,
                                             self.out_size)
        mask = self.mask_fc(x_flat).view(n, 1, self.out_size, self.out_size)

        # Second pass with the predicted offsets, modulated by the mask
        # (broadcast over the channel dimension).
        feat = deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.output_dim, self.no_trans, self.group_size,
            self.part_size, self.sample_per_part, self.trans_std)
        return feat * mask


class DeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling that predicts its own offsets.

    ``forward`` takes no ``offset`` argument. When ``no_trans`` is false:

    1. ``deform_roi_pooling`` is run without offsets (``no_trans=True``) to
       obtain per-RoI features.
    2. ``offset_fc`` maps the flattened features to 2 offset values per
       output cell.
    3. ``deform_roi_pooling`` is run again with those offsets.

    The final linear layer of ``offset_fc`` is zero-initialised, so the
    predicted offsets are 0 at initialisation.

    Args:
        spatial_scale, out_size, output_dim, no_trans, group_size,
        part_size, sample_per_part, trans_std: Forwarded to the
            ``DeformRoIPooling`` constructor.
        deform_fc_dim: Hidden width of the offset head.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 deform_fc_dim=1024):
        super(DeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, output_dim, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.deform_fc_dim = deform_fc_dim

        if not no_trans:
            # Input: flattened first-pass features, i.e.
            # out_size * out_size * output_dim values per RoI.
            self.offset_fc = nn.Sequential(
                nn.Linear(self.out_size * self.out_size * self.output_dim,
                          self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim,
                          self.out_size * self.out_size * 2))
            # Zero the last Linear so predicted offsets start at 0.
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with offsets predicted by ``offset_fc``.

        Args:
            data: Feature map passed to ``deform_roi_pooling``.
            rois: RoI tensor; its first dimension is the number of RoIs.

        Returns:
            The pooled features. In the learned-offset path with zero RoIs,
            an empty tensor of shape ``(0, output_dim, out_size, out_size)``
            is returned without running the pooling op.
        """
        offset = data.new_empty(0)
        if self.no_trans:
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.output_dim, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)

        n = rois.shape[0]
        if n == 0:
            # x.view(0, -1) below cannot infer -1 from a 0-element tensor,
            # so short-circuit with an empty result instead of crashing.
            return data.new_empty(n, self.output_dim, self.out_size,
                                  self.out_size)

        # First pass: pool without offsets (no_trans=True) to obtain the
        # features the offset head consumes.
        x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                               self.out_size, self.output_dim, True,
                               self.group_size, self.part_size,
                               self.sample_per_part, self.trans_std)
        offset = self.offset_fc(x.view(n, -1))
        offset = offset.view(n, 2, self.out_size, self.out_size)

        # Second pass with the predicted offsets.
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.output_dim, self.no_trans, self.group_size,
            self.part_size, self.sample_per_part, self.trans_std)