test_vote_fusion.py 14.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
"""Tests the core function of vote fusion.

CommandLine:
    pytest tests/test_models/test_fusion/test_vote_fusion.py
"""

import torch

from mmdet3d.models.fusion_layers import VoteFusion


def test_vote_fusion():
    """Check ``VoteFusion`` against precomputed reference outputs.

    The layer is run twice on the same synthetic image, calibration and
    15 3D seed points:

    1. With all 8 2D boxes: the vote mask and the first and last 15 vote
       columns of the fused features are compared with stored references.
    2. With only the first 2 boxes: vote columns 30:45 are expected to hold
       zeros in the first 15 feature channels and in the mask.

    Output shapes are asserted explicitly. Without that, the slices below
    would silently shrink (and ``allclose`` would broadcast or compare empty
    tensors) if the layer returned fewer vote columns than expected, so the
    value checks could pass vacuously.
    """
    # Augmentation metadata for a single sample.
    # NOTE(review): presumably used to undo the listed 3D transforms
    # ('HF', 'R', 'S', 'T') and the image rescale -- confirm in VoteFusion.
    img_meta = {
        'ori_shape': (530, 730, 3),
        'img_shape': (600, 826, 3),
        'pad_shape': (608, 832, 3),
        'scale_factor':
        torch.tensor([1.1315, 1.1321, 1.1315, 1.1321]),
        'flip':
        False,
        'pcd_horizontal_flip':
        False,
        'pcd_vertical_flip':
        False,
        'pcd_trans':
        torch.tensor([0., 0., 0.]),
        'pcd_scale_factor':
        1.0308290128214932,
        'pcd_rotation':
        torch.tensor([[0.9747, 0.2234, 0.0000], [-0.2234, 0.9747, 0.0000],
                      [0.0000, 0.0000, 1.0000]]),
        'transformation_3d_flow': ['HF', 'R', 'S', 'T']
    }

    # NOTE(review): 'K' looks like pinhole intrinsics (its principal point
    # is the centre of ori_shape) and 'Rt' a 3x3 rotation -- confirm.
    calibs = {
        'Rt':
        torch.tensor([[[0.979570, 0.047954, -0.195330],
                       [0.047954, 0.887470, 0.458370],
                       [0.195330, -0.458370, 0.867030]]]),
        'K':
        torch.tensor([[[529.5000, 0.0000, 365.0000],
                       [0.0000, 529.5000, 265.0000], [0.0000, 0.0000,
                                                      1.0000]]])
    }

    # One image with 8 boxes of 6 values each; presumably
    # (x1, y1, x2, y2, score, class_label) -- confirm.
    bboxes = torch.tensor([[[
        5.4286e+02, 9.8283e+01, 6.1700e+02, 1.6742e+02, 9.7922e-01, 3.0000e+00
    ], [
        4.2613e+02, 8.4646e+01, 4.9091e+02, 1.6237e+02, 9.7848e-01, 3.0000e+00
    ], [
        2.5606e+02, 7.3244e+01, 3.7883e+02, 1.8471e+02, 9.7317e-01, 3.0000e+00
    ], [
        6.0104e+02, 1.0648e+02, 6.6757e+02, 1.9216e+02, 8.4607e-01, 3.0000e+00
    ], [
        2.2923e+02, 1.4984e+02, 7.0163e+02, 4.6537e+02, 3.5719e-01, 0.0000e+00
    ], [
        2.5614e+02, 7.4965e+01, 3.3275e+02, 1.5908e+02, 2.8688e-01, 3.0000e+00
    ], [
        9.8718e+00, 1.4142e+02, 2.0213e+02, 3.3878e+02, 1.0935e-01, 3.0000e+00
    ], [
        6.1930e+02, 1.1768e+02, 6.8505e+02, 2.0318e+02, 1.0720e-01, 3.0000e+00
    ]]])

    # Batch of one sample with 15 seed points, 3 coordinates each.
    seeds_3d = torch.tensor([[[0.044544, 1.675476, -1.531831],
                              [2.500625, 7.238662, -0.737675],
                              [-0.600003, 4.827733, -0.084022],
                              [1.396212, 3.994484, -1.551180],
                              [-2.054746, 2.012759, -0.357472],
                              [-0.582477, 6.580470, -1.466052],
                              [1.313331, 5.722039, 0.123904],
                              [-1.107057, 3.450359, -1.043422],
                              [1.759746, 5.655951, -1.519564],
                              [-0.203003, 6.453243, 0.137703],
                              [-0.910429, 0.904407, -0.512307],
                              [0.434049, 3.032374, -0.763842],
                              [1.438146, 2.289263, -1.546332],
                              [0.575622, 5.041906, -0.891143],
                              [-1.675931, 1.417597, -1.588347]]])

    # Deterministic (1, 3, 608, 832) image: one linear ramp over the padded
    # size, repeated on all three channels.
    imgs = torch.linspace(
        -1, 1, steps=608 * 832).reshape(1, 608, 832).repeat(3, 1, 1)[None]

    # Row fragments shared by both feature references below.
    zero_row = [0.0] * 15
    # The last three channels are identical in both references.
    # NOTE(review): presumably sampled pixel values, equal across channels
    # because ``imgs`` repeats a single plane -- confirm.
    texture_row = [
        2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04,
        -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03,
        -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03,
        2.540967e-03, -1.834944e-03, 1.032048e-03
    ]

    # Reference for out1[:, :, :15]: 18 channels x first 15 vote columns.
    expected_tensor1 = torch.tensor([[
        [
            0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00,
            0.000000e+00, 1.193706e-01, -0.000000e+00, -2.879214e-01,
            -0.000000e+00, 0.000000e+00, 1.422463e-01, -6.474612e-01,
            -0.000000e+00, 1.490057e-02, 0.000000e+00
        ],
        [
            0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00,
            0.000000e+00, -1.873745e+00, -0.000000e+00, 1.576240e-01,
            0.000000e+00, -0.000000e+00, -3.646177e-02, -7.751858e-01,
            0.000000e+00, 9.593642e-02, 0.000000e+00
        ],
        [
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, -6.263277e-02, 0.000000e+00, -3.646387e-01,
            0.000000e+00, 0.000000e+00, -5.875812e-01, -6.263450e-02,
            0.000000e+00, 1.149264e-01, 0.000000e+00
        ],
        [
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 8.899736e-01, 0.000000e+00, 9.019017e-01,
            0.000000e+00, 0.000000e+00, 6.917775e-01, 8.899733e-01,
            0.000000e+00, 9.812444e-01, 0.000000e+00
        ],
        [
            -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
            -0.000000e+00, -4.516903e-01, -0.000000e+00, -2.315422e-01,
            -0.000000e+00, -0.000000e+00, -4.197519e-01, -4.516906e-01,
            -0.000000e+00, -1.547615e-01, -0.000000e+00
        ],
        [
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 3.571937e-01, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 0.000000e+00, 0.000000e+00, 3.571937e-01,
            0.000000e+00, 0.000000e+00, 0.000000e+00
        ],
        zero_row,
        zero_row,
        [
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 0.000000e+00, 0.000000e+00, 9.731653e-01,
            0.000000e+00, 0.000000e+00, 1.093455e-01, 0.000000e+00,
            0.000000e+00, 8.460656e-01, 0.000000e+00
        ],
        *([zero_row] * 6),
        *([texture_row] * 3),
    ]])

    # Reference for the vote mask over all 45 vote columns.
    expected_tensor2 = torch.tensor([[
        False, False, False, False, False, True, False, True, False, False,
        True, True, False, True, False, False, False, False, False, False,
        False, False, True, False, False, False, False, False, True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, True, False
    ]])

    # Reference for out1[:, :, 30:45]: 18 channels x last 15 vote columns.
    expected_tensor3 = torch.tensor([[
        [
            -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
            0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00,
            -0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00,
            -0.000000e+00, 1.720988e-01, 0.000000e+00
        ],
        [
            0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00,
            -0.000000e+00, 0.000000e+00, -0.000000e+00, 0.000000e+00,
            0.000000e+00, -0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 4.824460e-02, 0.000000e+00
        ],
        [
            -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
            -0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00,
            -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
            -0.000000e+00, 1.447314e-01, -0.000000e+00
        ],
        [
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 9.759269e-01, 0.000000e+00
        ],
        [
            -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
            -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
            -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
            -0.000000e+00, -1.631542e-01, -0.000000e+00
        ],
        zero_row,
        zero_row,
        zero_row,
        [
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
            0.000000e+00, 1.072001e-01, 0.000000e+00
        ],
        *([zero_row] * 6),
        *([texture_row] * 3),
    ]])

    # Expected output geometry, taken from the references themselves:
    # 18 feature channels and 45 vote columns (the 15 seeds times,
    # presumably, 3 image votes per seed -- confirm in VoteFusion).
    num_channels = expected_tensor1.shape[1]
    num_votes = expected_tensor2.shape[1]

    fusion = VoteFusion()

    # Case 1: all 8 boxes. The tolerance is passed as ``rtol``; that is what
    # the third positional argument of ``torch.allclose`` means anyway.
    out1, out2 = fusion(imgs, bboxes, seeds_3d, [img_meta], calibs)
    assert out1.shape == (1, num_channels, num_votes)
    assert out2.shape == (1, num_votes)
    assert torch.allclose(expected_tensor1, out1[:, :, :15], rtol=1e-3)
    assert torch.allclose(expected_tensor2.float(), out2.float(), rtol=1e-3)
    assert torch.allclose(expected_tensor3, out1[:, :, 30:45], rtol=1e-3)

    # Case 2: only the first 2 boxes. Vote columns 30:45 are expected to be
    # empty: zero values in the first 15 channels and a zero mask.
    out1, out2 = fusion(imgs, bboxes[:, :2], seeds_3d, [img_meta], calibs)
    assert out1.shape == (1, num_channels, num_votes)
    assert out2.shape == (1, num_votes)
    out1 = out1[:, :15, 30:45]
    out2 = out2[:, 30:45].float()
    # allclose tests |input - other| <= atol + rtol * |other|. With a zero
    # ``input``, the rtol term scales with the value under test, so these
    # checks effectively require |value| <= ~1e-8 (the default atol).
    assert torch.allclose(torch.zeros_like(out1), out1, rtol=1e-3)
    assert torch.allclose(torch.zeros_like(out2), out2, rtol=1e-3)