# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.ops import grouping_operation


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_grouping_points():
    """Check batched ``grouping_operation`` against hand-computed values.

    ``features`` holds 2 batches x 3 channels x 10 points and ``idx`` holds
    2 batches x 6 query rows x 3 samples. As the expected values below show,
    ``output[b, c, m, k] == features[b, c, idx[b, m, k]]``, so the output has
    shape (2, 3, 6, 3).
    """
    idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0],
                         [0, 0, 0]],
                        [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
                         [0, 0, 0]]]).int().cuda()
    features = torch.tensor([[[
        0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
        0.9268, 0.8414
    ],
                              [
                                  5.4247, 1.5113, 2.3944, 1.4740, 5.0300,
                                  5.1030, 1.9360, 2.1939, 2.1581, 3.4666
                              ],
                              [
                                  -1.6266, -1.0281, -1.0393, -1.6931, -1.3982,
                                  -0.5732, -1.0830, -1.7561, -1.6786, -1.6967
                              ]],
                             [[
                                 -0.0380, -0.1880, -1.5724, 0.6905, -0.3190,
                                 0.7798, -0.3693, -0.9457, -0.2942, -1.8527
                             ],
                              [
                                  1.1773, 1.5009, 2.6399, 5.9242, 1.0962,
                                  2.7346, 6.0865, 1.5555, 4.3303, 2.8229
                              ],
                              [
                                  -0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
                                  -1.4049, 0.4990, -0.7037, -0.9924, 0.0386
                              ]]]).cuda()

    output = grouping_operation(features, idx)

    # Each query row uses the same index for all 3 samples, so every inner
    # triple repeats one feature value picked by that index.
    expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798],
                                      [-1.3311, -1.3311, -1.3311],
                                      [0.9268, 0.9268, 0.9268],
                                      [0.5798, 0.5798, 0.5798],
                                      [0.5798, 0.5798, 0.5798],
                                      [0.5798, 0.5798, 0.5798]],
                                     [[5.4247, 5.4247, 5.4247],
                                      [1.4740, 1.4740, 1.4740],
                                      [2.1581, 2.1581, 2.1581],
                                      [5.4247, 5.4247, 5.4247],
                                      [5.4247, 5.4247, 5.4247],
                                      [5.4247, 5.4247, 5.4247]],
                                     [[-1.6266, -1.6266, -1.6266],
                                      [-1.6931, -1.6931, -1.6931],
                                      [-1.6786, -1.6786, -1.6786],
                                      [-1.6266, -1.6266, -1.6266],
                                      [-1.6266, -1.6266, -1.6266],
                                      [-1.6266, -1.6266, -1.6266]]],
                                    [[[-0.0380, -0.0380, -0.0380],
                                      [-0.3693, -0.3693, -0.3693],
                                      [-1.8527, -1.8527, -1.8527],
                                      [-0.0380, -0.0380, -0.0380],
                                      [-0.0380, -0.0380, -0.0380],
                                      [-0.0380, -0.0380, -0.0380]],
                                     [[1.1773, 1.1773, 1.1773],
                                      [6.0865, 6.0865, 6.0865],
                                      [2.8229, 2.8229, 2.8229],
                                      [1.1773, 1.1773, 1.1773],
                                      [1.1773, 1.1773, 1.1773],
                                      [1.1773, 1.1773, 1.1773]],
                                     [[-0.6646, -0.6646, -0.6646],
                                      [0.4990, 0.4990, 0.4990],
                                      [0.0386, 0.0386, 0.0386],
                                      [-0.6646, -0.6646, -0.6646],
                                      [-0.6646, -0.6646, -0.6646],
                                      [-0.6646, -0.6646, -0.6646]]]]).cuda()
    assert torch.allclose(output, expected_output)


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_stack_grouping_points():
    """Check ``grouping_operation`` in stacked (batch-concatenated) mode.

    Two batches of 3 points each (``features_batch_cnt``) are concatenated
    along dim 0 of ``features`` (6 points x 10 channels). Two batches of 6
    query rows each (``indices_batch_cnt``) are concatenated in ``idx``
    (12 rows x 3 samples). Indices are local to their batch. An index that
    is >= the batch's point count (3 here) is out of range, and the expected
    output for that query row is all zeros. Output shape is (12, 10, 3).
    """
    idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0],
                        [2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0],
                        [1, 1, 1], [0, 0, 0]]).int().cuda()
    features = torch.tensor([
        [0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
         0.9268, 0.8414],
        [5.4247, 1.5113, 2.3944, 1.4740, 5.0300, 5.1030, 1.9360, 2.1939,
         2.1581, 3.4666],
        [-1.6266, -1.0281, -1.0393, -1.6931, -1.3982, -0.5732, -1.0830,
         -1.7561, -1.6786, -1.6967],
        [-0.0380, -0.1880, -1.5724, 0.6905, -0.3190, 0.7798, -0.3693, -0.9457,
         -0.2942, -1.8527],
        [1.1773, 1.5009, 2.6399, 5.9242, 1.0962, 2.7346, 6.0865, 1.5555,
         4.3303, 2.8229],
        [-0.6646, -0.6870, -0.1125, -0.2224, -0.3445, -1.4049, 0.4990,
         -0.7037, -0.9924, 0.0386],
    ]).float().cuda()
    features_batch_cnt = torch.tensor([3, 3]).int().cuda()
    indices_batch_cnt = torch.tensor([6, 6]).int().cuda()
    output = grouping_operation(features, idx, features_batch_cnt,
                                indices_batch_cnt)

    # For each query row of ``idx``: the row of the stacked ``features``
    # tensor it should gather (batch-local index + batch offset 0 or 3), or
    # None when the batch-local index is out of range (3, 8, 6, 9 here) and
    # the result is expected to be zeros.
    source_rows = [0, None, None, 1, 0, 2, 3, None, None, 3, 4, 3]
    zeros = features.new_zeros(features.shape[1])
    expected_output = torch.stack(
        [zeros if row is None else features[row] for row in source_rows])
    # All 3 samples of a query row share the same index, so the gathered
    # channel vector is repeated along the trailing nsample axis.
    expected_output = expected_output.unsqueeze(-1).expand(-1, -1, 3)
    assert torch.allclose(output, expected_output)