sample_kernel.cu 10.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <cmath>
#include <stdexcept>
#include <stdint.h>
#include <string>
#include <vector>

// Threads per block used by every kernel launch in this file.
// farthest_point_sampling_kernel also sizes its __shared__ reduction buffers
// with this value, so the launched blockDim.x must not exceed THREADS. Its
// pairwise max-reduction only covers every slot when blockDim.x is a power of
// two, so keep THREADS a power of two.
#define THREADS 1024

// Original implementation by Qi et al. (https://github.com/charlesq34/pointnet2)

// Ball query: for every query point, collect up to `max_num_neighbors` points of
// the same batch whose squared Euclidean distance to the query is <= `radius`.
//
// Launch layout: one block per batch (blockIdx.x selects the batch); the threads
// of a block walk that batch's query points with a stride of blockDim.x.
//
// batch_slices / query_batch_slices: per-batch [start, end] pairs with an
//   inclusive end, stored flat as slices[2*b], slices[2*b+1].
// pos / query_pos: xyz coordinates packed as 3 consecutive scalars per point.
// radius: compared directly against the *squared* distance, so callers pass r*r.
// include_self: when false, candidates at squared distance exactly 0 are skipped
//   (this drops every coincident point, not only the query itself).
// idx_output: row-major (num_queries, max_num_neighbors) global point indices,
//   padded with -1 past the last hit.
// cnt_output: number of valid entries written to each query's row.
template <typename scalar_t>
__global__ void query_radius_cuda_kernel(
    const int64_t* __restrict__ batch_slices,
    const int64_t* __restrict__ query_batch_slices,
    const scalar_t* __restrict__ pos,
    const scalar_t* __restrict__ query_pos,
    const scalar_t radius,
    const int64_t max_num_neighbors,
    const bool include_self,
    int64_t* idx_output,
    int64_t* cnt_output)
{
    const int64_t b = blockIdx.x;

    const int64_t first = batch_slices[2 * b];
    const int64_t num_pts = batch_slices[2 * b + 1] - first + 1;
    const int64_t q_first = query_batch_slices[2 * b];
    const int64_t num_queries = query_batch_slices[2 * b + 1] - q_first + 1;

    // Views restricted to this block's batch.
    const scalar_t* pts = pos + first * 3;
    const scalar_t* queries = query_pos + q_first * 3;
    int64_t* nbr_idx = idx_output + q_first * max_num_neighbors;
    int64_t* nbr_cnt = cnt_output + q_first;

    for (int64_t q = threadIdx.x; q < num_queries; q += blockDim.x) {
        const scalar_t qx = queries[3 * q];
        const scalar_t qy = queries[3 * q + 1];
        const scalar_t qz = queries[3 * q + 2];
        int64_t* row = nbr_idx + q * max_num_neighbors;

        // Pad the whole row up front so unused slots read as -1.
        for (int64_t s = 0; s < max_num_neighbors; ++s)
            row[s] = -1;

        // Linear scan over the batch; stops as soon as the row is full.
        int64_t found = 0;
        for (int64_t p = 0; p < num_pts && found != max_num_neighbors; ++p) {
            const scalar_t dx = qx - pts[3 * p];
            const scalar_t dy = qy - pts[3 * p + 1];
            const scalar_t dz = qz - pts[3 * p + 2];
            const scalar_t sq = dx * dx + dy * dy + dz * dz;

            const bool in_ball = sq <= radius;
            const bool allowed = include_self || sq > 0;
            if (in_ball && allowed)
                row[found++] = first + p;
        }
        nbr_cnt[q] = found;
    }
}

// k-nearest-neighbour query: for every query point, find the `num_neighbors`
// nearest points of the same batch by squared Euclidean distance. Each query's
// candidate list is kept sorted in ascending distance by insertion.
//
// Launch layout: one block per batch (blockIdx.x selects the batch); the threads
// of a block walk that batch's query points with a stride of blockDim.x.
//
// batch_slices / query_batch_slices: per-batch [start, end] pairs with an
//   inclusive end, stored flat as slices[2*b], slices[2*b+1].
// pos / query_pos: xyz coordinates packed as 3 consecutive scalars per point.
// include_self: when false, candidates at squared distance exactly 0 are skipped
//   (this drops every coincident point, not only the query itself).
// tmp_dists: row-major (num_queries, num_neighbors) squared distances matching
//   idx_output; unfilled slots hold the dummy value 2147483647.
// idx_output: row-major (num_queries, num_neighbors) global point indices,
//   -1 for unfilled slots.
template <typename scalar_t>
__global__ void query_knn_cuda_kernel(
    const int64_t* __restrict__ batch_slices,
    const int64_t* __restrict__ query_batch_slices,
    const scalar_t* __restrict__ pos,
    const scalar_t* __restrict__ query_pos,
    const int64_t num_neighbors,
    const bool include_self,
    scalar_t* tmp_dists,
    int64_t* idx_output){

    const int64_t batch_index = blockIdx.x;
    const int64_t index = threadIdx.x;
    const int64_t stride = blockDim.x;

    const int64_t batch_start = batch_slices[2*batch_index];
    const int64_t query_batch_start = query_batch_slices[2*batch_index];
    const int64_t batch_end = batch_slices[2*batch_index+1];
    const int64_t query_batch_end = query_batch_slices[2*batch_index+1];

    const int64_t batch_size = batch_end - batch_start + 1;
    const int64_t query_batch_size = query_batch_end - query_batch_start + 1;

    pos += batch_start * 3;
    query_pos += query_batch_start * 3;
    idx_output += query_batch_start * num_neighbors;
    tmp_dists += query_batch_start * num_neighbors;

    for (int64_t j = index; j < query_batch_size; j += stride){
        const scalar_t x2 = query_pos[j*3+0];
        const scalar_t y2 = query_pos[j*3+1];
        const scalar_t z2 = query_pos[j*3+2];

        scalar_t* row_dists = tmp_dists + j * num_neighbors;
        int64_t* row_idx = idx_output + j * num_neighbors;

        // Reset to dummy values. An index of -1 marks an empty slot; the dummy
        // distance is only what unfilled slots report in the output.
        for (int64_t l = 0; l < num_neighbors; l++){
            row_idx[l] = -1;
            row_dists[l] = 2147483647;
        }

        for (int64_t k = 0; k < batch_size; k++) {
            const scalar_t x1 = pos[k*3+0];
            const scalar_t y1 = pos[k*3+1];
            const scalar_t z1 = pos[k*3+2];

            const scalar_t d = (x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);

            // NaN distances (d != d) are never ranked as neighbours.
            if (d != d){
                continue;
            }
            if (!(d > 0 || include_self)){
                continue;
            }

            for (int64_t i = 0; i < num_neighbors; i++){
                // Detect empty slots by their -1 index instead of comparing with
                // the dummy distance. The old comparison silently dropped every
                // candidate whose squared distance was >= 2147483647.
                if (row_idx[i] == -1 || row_dists[i] > d){
                    // Shift the tail one slot to the right to make room at i.
                    for (int64_t i2 = num_neighbors-1; i2 > i; i2--){
                        row_dists[i2] = row_dists[i2 - 1];
                        row_idx[i2] = row_idx[i2 - 1];
                    }
                    row_dists[i] = d;
                    row_idx[i] = batch_start + k;
                    break;
                }
            }
        }
    }
}

// Farthest point sampling. For each batch, greedily pick `num_sample` points,
// starting from start_points[b]; each subsequent pick is the point whose squared
// distance to its nearest already-picked point is largest.
//
// Launch layout: one block per batch (blockIdx.x selects the batch). blockDim.x
// must not exceed THREADS, because the __shared__ reduction buffers below have
// THREADS entries. It must also be a power of two, because the pairwise
// max-reduction only visits every slot for power-of-two block sizes.
//
// batch_slices: per-batch [start, end] pairs with an inclusive end, stored as
//   batch_slices[2*b], batch_slices[2*b+1].
// pos: xyz coordinates packed as 3 consecutive scalars per point.
// start_points: per-batch index of the first sample, relative to the batch start.
// tmp_dists: scratch buffer with one entry per point, addressed as
//   tmp_dists[batch_start + j].
// idx_output: flat buffer with num_sample entries per batch (slot
//   b*num_sample + i) holding global point indices. Slots that cannot be filled
//   are set to -1. This happens when the batch has fewer than num_sample
//   points, or when every remaining point coincides with an already-picked one.
template <typename scalar_t>
__global__ void farthest_point_sampling_kernel(
    const int64_t* __restrict__ batch_slices,
    const scalar_t* __restrict__ pos,
    const int64_t num_sample,
    const int64_t* __restrict__ start_points,
    scalar_t* tmp_dists,
    int64_t* idx_output){

    // Nothing to sample. Returning here also avoids writing idx_output[0], which
    // would be out of bounds when num_sample == 0. The condition is uniform
    // across the block, so no thread skips a later barrier.
    if (num_sample <= 0){
        return;
    }

    const int64_t batch_index = blockIdx.x;
    const int64_t index = threadIdx.x;
    const int64_t stride = blockDim.x;

    const int64_t batch_start = batch_slices[2*batch_index];
    const int64_t batch_end = batch_slices[2*batch_index+1];
    const int64_t batch_size = batch_end - batch_start + 1;

    __shared__ scalar_t dists[THREADS];
    __shared__ int64_t dists_i[THREADS];

    pos += batch_start * 3;
    idx_output += num_sample * batch_index;
    tmp_dists += batch_start;

    int64_t old = start_points[batch_index];

    // Thread 0 writes the first sample. If the batch has fewer than num_sample
    // points, it instead writes all of them followed by -1 padding.
    if (index == 0){
        if (batch_size < num_sample){
            for (int64_t i = 0; i < batch_size; i++){
                idx_output[i] = batch_start + i;
            }
            for (int64_t i = batch_size; i < num_sample; i++){
                idx_output[i] = -1;
            }
        } else {
            idx_output[0] = batch_start + old;
        }
    }
    // Uniform across the block: every thread leaves together.
    if (batch_size < num_sample){
        return;
    }

    // Running "distance to nearest picked point", initialised to +infinity.
    // The previous finite sentinel (2147483647) clamped every larger squared
    // distance, which made far-apart points indistinguishable during selection.
    for (int64_t j = index; j < batch_size; j += stride){
        tmp_dists[j] = (scalar_t) INFINITY;
    }

    __syncthreads();
    for (int64_t i = 1; i < num_sample; i++){
        int64_t besti = -1;
        scalar_t best = -1;

        // The last picked point is the same for every j, so load it once.
        const scalar_t x1 = pos[old * 3 + 0];
        const scalar_t y1 = pos[old * 3 + 1];
        const scalar_t z1 = pos[old * 3 + 2];

        // Fold in the distance to the last picked point and track this
        // thread's farthest candidate.
        for (int64_t j = index; j < batch_size; j += stride){
            const scalar_t x2 = pos[j * 3 + 0];
            const scalar_t y2 = pos[j * 3 + 1];
            const scalar_t z2 = pos[j * 3 + 2];

            const scalar_t d = (x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
            const scalar_t td = tmp_dists[j];
            const scalar_t d2 = min(d, td);
            if (d2 != td){
                tmp_dists[j] = d2;
            }

            if (d2 > best){
                best = d2;
                besti = j;
            }
        }

        // Block-wide arg-max of the per-thread candidates (pairwise tree reduction).
        dists[index] = best;
        dists_i[index] = besti;

        __syncthreads();
        for (int64_t u = 0; (1<<u) < blockDim.x ; u++){
            __syncthreads();
            if (index < (blockDim.x >> (u+1))){
                int64_t i1 = (index*2)<<u;
                int64_t i2 = (index*2+1)<<u;
                if (dists[i1] < dists[i2]){
                    dists[i1] = dists[i2];
                    dists_i[i1] = dists_i[i2];
                }
            }
        }

        __syncthreads();

        // Copy the winner into registers, then wait until every thread has done
        // so. Without this barrier a fast thread could overwrite dists[0] and
        // dists_i[0] at the top of the next iteration while slower threads still
        // need this iteration's values.
        const scalar_t max_dist = dists[0];
        const int64_t next = dists_i[0];
        __syncthreads();

        // Every remaining point coincides with a picked one, so stop early and
        // mark the unused slots, which would otherwise be left unwritten.
        // Every thread reads the same max_dist, so the branch is uniform.
        if (max_dist == 0){
            if (index == 0){
                for (int64_t r = i; r < num_sample; r++){
                    idx_output[r] = -1;
                }
            }
            break;
        }

        old = next;
        if (index == 0){
            idx_output[i] = batch_start + old;
        }
    }
}


// Host entry point for the ball query. Launches query_radius_cuda_kernel on the
// default stream with one block of THREADS threads per batch.
//
// batch_size: number of batches, i.e. the number of [start, end] pairs in
//   batch_slices / query_batch_slices. Both tensors are read as int64 on the
//   device.
// radius: search radius. The kernel receives radius*radius and compares it with
//   squared distances.
// Returns {idx_output (num_queries x max_num_neighbors, int64),
//          cnt_output (num_queries, int64)}.
// Throws std::runtime_error if the kernel launch is rejected.
std::vector<at::Tensor> query_radius_cuda(
    int batch_size,
    at::Tensor batch_slices,
    at::Tensor query_batch_slices,
    at::Tensor pos,
    at::Tensor query_pos,
    double radius,
    int max_num_neighbors,
    bool include_self) {

  const auto num_points = query_pos.size(0);

  auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {num_points, max_num_neighbors});
  auto cnt_output = at::empty(pos.type().toScalarType(at::kLong), {num_points});

  // A zero-block grid is an invalid launch configuration; with no batches there
  // is nothing to compute.
  if (batch_size <= 0) {
    return {idx_output, cnt_output};
  }

  AT_DISPATCH_FLOATING_TYPES(pos.type(), "query_radius_cuda_kernel", [&] {
      query_radius_cuda_kernel<scalar_t><<<batch_size, THREADS>>>(
        batch_slices.data<int64_t>(),
        query_batch_slices.data<int64_t>(),
        pos.data<scalar_t>(),
        query_pos.data<scalar_t>(),
        // Square in double before narrowing. `(scalar_t) radius*radius` only
        // cast the first factor.
        (scalar_t) (radius * radius),
        max_num_neighbors,
        include_self,
        idx_output.data<int64_t>(),
        cnt_output.data<int64_t>());
  });

  // Kernel launches return no status. Report failures here instead of handing
  // back uninitialised tensors.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string("query_radius_cuda_kernel launch failed: ") +
                             cudaGetErrorString(err));
  }

  return {idx_output, cnt_output};
}


// Host entry point for the k-nearest-neighbour query. Launches
// query_knn_cuda_kernel on the default stream with one block of THREADS threads
// per batch.
//
// batch_size: number of batches, i.e. the number of [start, end] pairs in
//   batch_slices / query_batch_slices. Both tensors are read as int64 on the
//   device.
// Returns {idx_output (num_queries x num_neighbors, int64, -1 where unfilled),
//          dists (num_queries x num_neighbors, same dtype as pos, squared
//          distances matching idx_output)}.
// Throws std::runtime_error if the kernel launch is rejected.
std::vector<at::Tensor> query_knn_cuda(
    int batch_size,
    at::Tensor batch_slices,
    at::Tensor query_batch_slices,
    at::Tensor pos,
    at::Tensor query_pos,
    int num_neighbors,
    bool include_self) {

  const auto num_points = query_pos.size(0);

  auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {num_points, num_neighbors});
  auto dists = at::empty(pos.type(), {num_points, num_neighbors});

  // A zero-block grid is an invalid launch configuration; with no batches there
  // is nothing to compute.
  if (batch_size <= 0) {
    return {idx_output, dists};
  }

  AT_DISPATCH_FLOATING_TYPES(pos.type(), "query_knn_cuda_kernel", [&] {
    query_knn_cuda_kernel<scalar_t><<<batch_size, THREADS>>>(
      batch_slices.data<int64_t>(),
      query_batch_slices.data<int64_t>(),
      pos.data<scalar_t>(),
      query_pos.data<scalar_t>(),
      num_neighbors,
      include_self,
      dists.data<scalar_t>(),
      idx_output.data<int64_t>());
  });

  // Kernel launches return no status. Report failures here instead of handing
  // back uninitialised tensors.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string("query_knn_cuda_kernel launch failed: ") +
                             cudaGetErrorString(err));
  }

  return {idx_output, dists};
}

// Host entry point for farthest point sampling. Launches
// farthest_point_sampling_kernel on the default stream with one block of THREADS
// threads per batch, plus a scratch buffer holding one distance per point of pos.
//
// batch_size: number of batches, i.e. the number of [start, end] pairs in
//   batch_slices (read as int64 on the device).
// start_points: per-batch index of the first sample (read as int64).
// Returns a flat int64 tensor of batch_size * num_sample global point indices,
// with num_sample consecutive entries per batch.
// Throws std::runtime_error if the kernel launch is rejected.
at::Tensor farthest_point_sampling_cuda(
    int batch_size,
    at::Tensor batch_slices,
    at::Tensor pos,
    int num_sample,
    at::Tensor start_points) {

  // Multiply in 64 bits so large batch counts cannot overflow int.
  const int64_t total_samples = static_cast<int64_t>(batch_size) * num_sample;
  auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {total_samples});

  // Nothing to sample. Also avoids a zero-block launch (invalid configuration)
  // and a kernel run that would write past an empty output.
  if (batch_size <= 0 || num_sample <= 0) {
    return idx_output;
  }

  auto tmp_dists = at::empty(pos.type(), {pos.size(0)});

  AT_DISPATCH_FLOATING_TYPES(pos.type(), "farthest_point_sampling_kernel", [&] {
    farthest_point_sampling_kernel<scalar_t><<<batch_size, THREADS>>>(
      batch_slices.data<int64_t>(),
      pos.data<scalar_t>(),
      num_sample,
      start_points.data<int64_t>(),
      tmp_dists.data<scalar_t>(),
      idx_output.data<int64_t>());
  });

  // Kernel launches return no status. Report failures here instead of handing
  // back an uninitialised tensor.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string("farthest_point_sampling_kernel launch failed: ") +
                             cudaGetErrorString(err));
  }

  return idx_output;
}