// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <torch/serialize/tensor.h>
#include <torch/types.h>
// Threads per block used by every kernel launch in this file (a multiple of
// the 32-lane warp size).
#define THREADS_PER_BLOCK 256
// Ceiling division: number of chunks of size n needed to cover m items.
// Intended for non-negative integers. Each argument is evaluated more than
// once, so do not pass expressions with side effects.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

// Uncomment to enable device-side printf tracing in the kernels below; the
// forward launcher then calls cudaDeviceSynchronize() so output is flushed.
// #define DEBUG

/**
 * Rotate a 2D offset by -rz about the z axis, i.e. express a point that has
 * already been translated relative to a box center in that box's local frame.
 *
 * @param shift_x, shift_y  point (x, y) minus box center (x, y)
 * @param rz                box rotation about the z axis, in radians
 * @param local_x, local_y  [out] coordinates in the box-aligned frame
 *
 * Declared __host__ __device__ so it can also be used for CPU reference
 * checks; device-side callers are unaffected.
 */
__host__ __device__ inline void lidar_to_local_coords(float shift_x,
                                                      float shift_y, float rz,
                                                      float &local_x,
                                                      float &local_y) {
  // Explicit single-precision trig.
  const float cosa = cosf(-rz);
  const float sina = sinf(-rz);

  local_x = shift_x * cosa - shift_y * sina;
  local_y = shift_x * sina + shift_y * cosa;
}

/**
 * Test whether a point lies inside a 3D box rotated about the z axis.
 *
 * @param pt       (x, y, z) of the point
 * @param box3d    (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinates, where cz
 *                 is the BOTTOM of the box (not its vertical center)
 * @param local_x, local_y  [out] point position in the box-aligned frame,
 *                 relative to the box center; only written when the point
 *                 passes the vertical test
 * @return 1 if |z - z_center| <= dz/2 and the point is strictly inside the
 *         box footprint in the local frame, otherwise 0
 */
__device__ inline int check_pt_in_box3d(const float *pt, const float *box3d,
                                        float &local_x, float &local_y) {
  const float x = pt[0], y = pt[1], z = pt[2];
  const float cx = box3d[0], cy = box3d[1];
  const float dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6];
  // box3d stores the bottom center; shift to the geometric center in z.
  // Float-only arithmetic (0.5f) avoids double promotion; halving is exact.
  const float cz = box3d[2] + dz * 0.5f;

  // Cheap vertical rejection before any trigonometry.
  if (fabsf(z - cz) > dz * 0.5f) return 0;

  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);

  const float half_dx = dx * 0.5f;
  const float half_dy = dy * 0.5f;
  return (local_x > -half_dx) && (local_x < half_dx) &&
         (local_y > -half_dy) && (local_y < half_dy);
}

/**
 * For every (box, point) pair, record which voxel of the box the point falls
 * into.
 *
 * Launch layout: gridDim.y == boxes_num and gridDim.x * blockDim.x >= pts_num;
 * one thread per (point, box).
 *
 * @param rois      (N, 7) [x, y, z, dx, dy, dz, rz] in LiDAR coordinates,
 *                  z being the bottom of the box
 * @param pts       (npoints, 3) [x, y, z]
 * @param pts_mask  (N, npoints) output: -1 if the point is not in the box,
 *                  otherwise the voxel index packed as
 *                  (x_idx << 16) + (y_idx << 8) + z_idx.
 *                  collect_inside_pts_for_box3d decodes each field with
 *                  & 0xFF, so out_x, out_y and out_z must all be <= 256.
 */
__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
                                            int out_x, int out_y, int out_z,
                                            const float *rois, const float *pts,
                                            int *pts_mask) {
  const int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int box_idx = blockIdx.y;
  if (pt_idx >= pts_num || box_idx >= boxes_num) return;

  pts += pt_idx * 3;
  rois += box_idx * 7;
  // 64-bit offset: boxes_num * pts_num can exceed INT_MAX for large inputs.
  pts_mask += (size_t)box_idx * pts_num + pt_idx;

  float local_x = 0, local_y = 0;
  const int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y);

  pts_mask[0] = -1;
  if (cur_in_flag > 0) {
    const float local_z = pts[2] - rois[2];  // height above the box bottom
    const float dx = rois[3], dy = rois[4], dz = rois[5];

    const float x_res = dx / out_x;
    const float y_res = dy / out_y;
    const float z_res = dz / out_z;

    // Keep indices signed while clamping: with unsigned indices max(idx, 0)
    // is a no-op and a negative value would wrap and clamp to the LAST voxel.
    int x_idx = int((local_x + dx / 2) / x_res);
    int y_idx = int((local_y + dy / 2) / y_res);
    int z_idx = int(local_z / z_res);

    x_idx = min(max(x_idx, 0), out_x - 1);
    y_idx = min(max(y_idx, 0), out_y - 1);
    z_idx = min(max(z_idx, 0), out_z - 1);

    const unsigned int idx_encoding = ((unsigned int)x_idx << 16) +
                                      ((unsigned int)y_idx << 8) +
                                      (unsigned int)z_idx;

#ifdef DEBUG
    printf(
        "mask: pts_%d(%.3f, %.3f, %.3f), local(%.3f, %.3f, %.3f), idx(%d, %d, "
        "%d), res(%.3f, %.3f, %.3f), idx_encoding=%x\n",
        pt_idx, pts[0], pts[1], pts[2], local_x, local_y, local_z, x_idx, y_idx,
        z_idx, x_res, y_res, z_res, idx_encoding);
#endif

    pts_mask[0] = idx_encoding;
  }
}

/**
 * Scatter the in-box points of each box into per-voxel index lists.
 *
 * Launch layout: 1D grid, one thread per box. Each thread walks all points of
 * its own box sequentially and boxes own disjoint output regions, so no
 * atomics are needed.
 *
 * @param pts_mask  (N, npoints): -1 if the point is not in the box, otherwise
 *                  the packed voxel index (x_idx << 16) + (y_idx << 8) + z_idx
 *                  with 8 bits per field (so out_x/out_y/out_z <= 256)
 * @param pts_idx_of_voxels (N, out_x, out_y, out_z, max_pts_each_voxel)
 *                  in/out: slot 0 of each voxel is a point counter, slots
 *                  1..max_pts_each_voxel-1 receive point indices; points beyond
 *                  a voxel's capacity are dropped. The counter is read before
 *                  being incremented, so the buffer is presumably
 *                  zero-initialized by the caller (not visible here).
 */
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
                                             int max_pts_each_voxel, int out_x,
                                             int out_y, int out_z,
                                             const int *pts_mask,
                                             int *pts_idx_of_voxels) {
  const int box_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Signed capacity (slot 0 is the counter). A capacity <= 0 means no point
  // can be stored; bailing out also avoids touching a zero-sized voxel.
  const int max_num_pts = max_pts_each_voxel - 1;
  if (box_idx >= boxes_num || max_num_pts <= 0) return;

  // 64-bit offsets: N * voxels * max_pts_each_voxel and N * npoints can both
  // exceed INT_MAX.
  int *box_voxels = pts_idx_of_voxels + (size_t)box_idx * out_x * out_y *
                                            out_z * max_pts_each_voxel;
  const int *box_mask = pts_mask + (size_t)box_idx * pts_num;

  for (int k = 0; k < pts_num; k++) {
    if (box_mask[k] == -1) continue;  // point k is not inside this box

    const unsigned int idx_encoding = box_mask[k];
    const unsigned int x_idx = (idx_encoding >> 16) & 0xFF;
    const unsigned int y_idx = (idx_encoding >> 8) & 0xFF;
    const unsigned int z_idx = idx_encoding & 0xFF;
    int *voxel = box_voxels + (((size_t)x_idx * out_y + y_idx) * out_z + z_idx) *
                                  max_pts_each_voxel;

    const int cnt = voxel[0];
    if (cnt < max_num_pts) {
      voxel[cnt + 1] = k;
      voxel[0] = cnt + 1;
    }

#ifdef DEBUG
    printf("collect: pts_%d, idx(%d, %d, %d), idx_encoding=%x\n", k, x_idx,
           y_idx, z_idx, idx_encoding);
#endif
  }
}

/**
 * Max-pool point features over the points collected in each voxel of each box.
 *
 * Launch layout: gridDim = (ceil(out_x*out_y*out_z / blockDim.x), channels,
 * boxes_num); one thread per (voxel, channel, box).
 *
 * @param pts_num           unused here; kept for interface compatibility
 * @param pts_feature       (npoints, C)
 * @param pts_idx_of_voxels (N, out_x, out_y, out_z, max_pts_each_voxel);
 *                          slot 0 of each voxel is the point count, followed
 *                          by that many point indices
 * @param pooled_features   (N, out_x, out_y, out_z, C) output; left untouched
 *                          when no point was selected (e.g. empty voxels)
 * @param argmax            (N, out_x, out_y, out_z, C) output: index of the
 *                          winning point, or -1 when none was selected
 */
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
                                   int max_pts_each_voxel, int out_x, int out_y,
                                   int out_z, const float *pts_feature,
                                   const int *pts_idx_of_voxels,
                                   float *pooled_features, int *argmax) {
  const int box_idx = blockIdx.z;
  const int channel_idx = blockIdx.y;
  const int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;

  // Unflatten the voxel index (x slowest, z fastest).
  const int x_idx = voxel_idx_flat / (out_y * out_z);
  const int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
  const int z_idx = voxel_idx_flat % out_z;
  if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
      y_idx >= out_y || z_idx >= out_z)
    return;

#ifdef DEBUG
  printf("src pts_idx_of_voxels: (%p, ), argmax: %p\n", pts_idx_of_voxels,
         argmax);
#endif

  // Linear (box, voxel) cell index, 64-bit so the strided offsets below cannot
  // overflow int for large N * voxels * C.
  const int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
  const size_t cell = (size_t)box_idx * out_x * out_y * out_z + offset_base;
  pts_idx_of_voxels += cell * max_pts_each_voxel;
  pooled_features += cell * channels + channel_idx;
  argmax += cell * channels + channel_idx;

  int argmax_idx = -1;
  // Every finite feature compares greater than -INFINITY (a literal such as
  // -1e50 is a double that does not fit in a float).
  float max_val = -INFINITY;

  const int total_pts = pts_idx_of_voxels[0];
  for (int k = 1; k <= total_pts; k++) {
    const int pt_idx = pts_idx_of_voxels[k];
    const float val = pts_feature[(size_t)pt_idx * channels + channel_idx];
    if (val > max_val) {
      max_val = val;
      argmax_idx = pt_idx;
    }
  }

  if (argmax_idx != -1) {
    pooled_features[0] = max_val;
  }
  argmax[0] = argmax_idx;

#ifdef DEBUG
  printf(
      "channel_%d idx(%d, %d, %d), argmax_idx=(%d, %.3f), total=%d, after "
      "pts_idx: %p, argmax: (%p, %d)\n",
      channel_idx, x_idx, y_idx, z_idx, argmax_idx, max_val, total_pts,
      pts_idx_of_voxels, argmax, argmax_idx);
#endif
}

/**
 * Average-pool point features over the points collected in each voxel of each
 * box.
 *
 * Launch layout: gridDim = (ceil(out_x*out_y*out_z / blockDim.x), channels,
 * boxes_num); one thread per (voxel, channel, box).
 *
 * @param pts_num           unused here; kept for interface compatibility
 * @param pts_feature       (npoints, C)
 * @param pts_idx_of_voxels (N, out_x, out_y, out_z, max_pts_each_voxel);
 *                          slot 0 of each voxel is the point count, followed
 *                          by that many point indices
 * @param pooled_features   (N, out_x, out_y, out_z, C) output; left untouched
 *                          for empty voxels
 */
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
                                   int max_pts_each_voxel, int out_x, int out_y,
                                   int out_z, const float *pts_feature,
                                   const int *pts_idx_of_voxels,
                                   float *pooled_features) {
  const int box_idx = blockIdx.z;
  const int channel_idx = blockIdx.y;
  const int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;

  // Unflatten the voxel index (x slowest, z fastest).
  const int x_idx = voxel_idx_flat / (out_y * out_z);
  const int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
  const int z_idx = voxel_idx_flat % out_z;
  if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
      y_idx >= out_y || z_idx >= out_z)
    return;

  // Linear (box, voxel) cell index, 64-bit to keep strided offsets in range.
  const int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
  const size_t cell = (size_t)box_idx * out_x * out_y * out_z + offset_base;
  pts_idx_of_voxels += cell * max_pts_each_voxel;
  pooled_features += cell * channels + channel_idx;

  float sum_val = 0;
  const int total_pts = pts_idx_of_voxels[0];
  for (int k = 1; k <= total_pts; k++) {
    sum_val +=
        pts_feature[(size_t)pts_idx_of_voxels[k] * channels + channel_idx];
  }

  if (total_pts > 0) {
    pooled_features[0] = sum_val / total_pts;
  }
}

/**
 * Host entry point for the forward pass of RoI-aware 3D pooling.
 *
 * Pipeline, launched in order on the default stream:
 *   1. generate_pts_mask_for_box3d: per (box, point) voxel assignment into a
 *      temporary (N, npoints) device buffer.
 *   2. collect_inside_pts_for_box3d: scatter point indices into
 *      pts_idx_of_voxels.
 *   3. roiaware_maxpool3d or roiaware_avgpool3d, depending on pool_method.
 * NOTE(review): launches use the default stream, not a caller-provided one.
 *
 * @param rois        (N, 7) [x, y, z, dx, dy, dz, rz] in LiDAR coordinates
 * @param pts         (npoints, 3) [x, y, z] in LiDAR coordinates
 * @param pts_feature (npoints, C)
 * @param argmax      (N, out_x, out_y, out_z, C); written by max pooling only
 * @param pts_idx_of_voxels (N, out_x, out_y, out_z, max_pts_each_voxel);
 *                    slot 0 of each voxel is read as a counter before being
 *                    incremented, so presumably the caller zero-initializes
 *                    it — TODO confirm
 * @param pooled_features (N, out_x, out_y, out_z, C)
 * @param pool_method 0: max_pool, 1: avg_pool; other values skip pooling
 *
 * Preconditions: out_x, out_y, out_z <= 256 (voxel indices are packed into
 * 8-bit fields); boxes_num and channels <= 65535 (used as grid y/z sizes).
 *
 * CUDA failures (allocation, memset, launch) are reported on stderr and the
 * remaining steps are skipped; the temporary buffer is always released.
 * Launch errors are read with cudaPeekAtLastError, so they stay pending for
 * the caller's own error checks.
 */
void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels,
                              int max_pts_each_voxel, int out_x, int out_y,
                              int out_z, const float *rois, const float *pts,
                              const float *pts_feature, int *argmax,
                              int *pts_idx_of_voxels, float *pooled_features,
                              int pool_method) {
  // Logs a failed CUDA call to stderr; returns true on cudaSuccess.
  auto cuda_ok = [](cudaError_t err, const char *what) {
    if (err == cudaSuccess) return true;
    fprintf(stderr, "roiaware_pool3d_launcher: %s failed: %s\n", what,
            cudaGetErrorString(err));
    return false;
  };

  // No boxes: all outputs are empty, and a zero-sized grid is an invalid
  // launch configuration.
  if (boxes_num <= 0) return;

  const int num_voxels = out_x * out_y * out_z;
  dim3 threads(THREADS_PER_BLOCK);
  bool ok = true;

  int *pts_mask = NULL;
  if (pts_num > 0) {
    // (N, M) temporary. Size computed in 64 bits: the int product
    // boxes_num * pts_num can overflow before being widened.
    const size_t mask_bytes = (size_t)boxes_num * pts_num * sizeof(int);
    if (!cuda_ok(cudaMalloc(&pts_mask, mask_bytes), "cudaMalloc(pts_mask)")) {
      return;  // nothing allocated yet
    }
    ok = cuda_ok(cudaMemset(pts_mask, -1, mask_bytes), "cudaMemset(pts_mask)");

    if (ok) {
      dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num);
      generate_pts_mask_for_box3d<<<blocks_mask, threads>>>(
          boxes_num, pts_num, out_x, out_y, out_z, rois, pts, pts_mask);
      ok = cuda_ok(cudaPeekAtLastError(), "generate_pts_mask_for_box3d launch");
    }

    // TODO: Merge the collect and pool functions, SS
    if (ok) {
      dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK));
      collect_inside_pts_for_box3d<<<blocks_collect, threads>>>(
          boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z, pts_mask,
          pts_idx_of_voxels);
      ok = cuda_ok(cudaPeekAtLastError(),
                   "collect_inside_pts_for_box3d launch");
    }
  }

  // Pooling still runs when pts_num == 0 so max pooling writes argmax = -1 for
  // every voxel (assuming the caller zeroed the pts_idx_of_voxels counters).
  if (ok && num_voxels > 0 && channels > 0) {
    dim3 blocks_pool(DIVUP(num_voxels, THREADS_PER_BLOCK), channels,
                     boxes_num);
    if (pool_method == 0) {
      roiaware_maxpool3d<<<blocks_pool, threads>>>(
          boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
          pts_feature, pts_idx_of_voxels, pooled_features, argmax);
      cuda_ok(cudaPeekAtLastError(), "roiaware_maxpool3d launch");
    } else if (pool_method == 1) {
      roiaware_avgpool3d<<<blocks_pool, threads>>>(
          boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
          pts_feature, pts_idx_of_voxels, pooled_features);
      cuda_ok(cudaPeekAtLastError(), "roiaware_avgpool3d launch");
    } else {
      fprintf(stderr, "roiaware_pool3d_launcher: unsupported pool_method %d\n",
              pool_method);
    }
  }

  if (pts_mask != NULL) {
    cuda_ok(cudaFree(pts_mask), "cudaFree(pts_mask)");
  }

#ifdef DEBUG
  cudaDeviceSynchronize();  // for using printf in kernel function
#endif
}

/**
 * Backward pass of roiaware_maxpool3d: route each pooled gradient to the point
 * recorded in argmax.
 *
 * Launch layout: gridDim = (ceil(out_x*out_y*out_z / blockDim.x), channels,
 * boxes_num); one thread per (voxel, channel, box).
 *
 * @param argmax   (N, out_x, out_y, out_z, C); -1 means no point was selected
 *                 and the cell contributes no gradient
 * @param grad_out (N, out_x, out_y, out_z, C)
 * @param grad_in  (npoints, C), accumulated with atomicAdd because several
 *                 (box, voxel) cells may select the same point; presumably
 *                 zero-initialized by the caller
 */
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
                                            int out_x, int out_y, int out_z,
                                            const int *argmax,
                                            const float *grad_out,
                                            float *grad_in) {
  const int box_idx = blockIdx.z;
  const int channel_idx = blockIdx.y;
  const int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;

  // Unflatten the voxel index (x slowest, z fastest).
  const int x_idx = voxel_idx_flat / (out_y * out_z);
  const int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
  const int z_idx = voxel_idx_flat % out_z;
  if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
      y_idx >= out_y || z_idx >= out_z)
    return;

  // Linear (box, voxel) cell index, 64-bit to keep strided offsets in range.
  const int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
  const size_t cell = (size_t)box_idx * out_x * out_y * out_z + offset_base;
  argmax += cell * channels + channel_idx;
  grad_out += cell * channels + channel_idx;

  const int winner = argmax[0];
  if (winner == -1) return;

  atomicAdd(grad_in + (size_t)winner * channels + channel_idx, grad_out[0]);
}

/**
 * Backward pass of roiaware_avgpool3d: spread each pooled gradient evenly over
 * the points collected in the voxel.
 *
 * Launch layout: gridDim = (ceil(out_x*out_y*out_z / blockDim.x), channels,
 * boxes_num); one thread per (voxel, channel, box).
 *
 * @param pts_idx_of_voxels (N, out_x, out_y, out_z, max_pts_each_voxel);
 *                          slot 0 is the point count, followed by indices
 * @param grad_out          (N, out_x, out_y, out_z, C)
 * @param grad_in           (npoints, C), accumulated with atomicAdd because a
 *                          point may appear in several boxes; presumably
 *                          zero-initialized by the caller
 */
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
                                            int out_x, int out_y, int out_z,
                                            int max_pts_each_voxel,
                                            const int *pts_idx_of_voxels,
                                            const float *grad_out,
                                            float *grad_in) {
  const int box_idx = blockIdx.z;
  const int channel_idx = blockIdx.y;
  const int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;

  // Unflatten the voxel index (x slowest, z fastest).
  const int x_idx = voxel_idx_flat / (out_y * out_z);
  const int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
  const int z_idx = voxel_idx_flat % out_z;
  if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
      y_idx >= out_y || z_idx >= out_z)
    return;

  // Linear (box, voxel) cell index, 64-bit to keep strided offsets in range.
  const int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
  const size_t cell = (size_t)box_idx * out_x * out_y * out_z + offset_base;
  pts_idx_of_voxels += cell * max_pts_each_voxel;
  grad_out += cell * channels + channel_idx;

  const int total_pts = pts_idx_of_voxels[0];
  if (total_pts <= 0) return;  // empty voxel: nothing to distribute

  // Loop-invariant share per point (float literal: no double promotion).
  const float cur_grad = 1.0f / (float)total_pts;
  const float share = grad_out[0] * cur_grad;
  for (int k = 1; k <= total_pts; k++) {
    atomicAdd(grad_in + (size_t)pts_idx_of_voxels[k] * channels + channel_idx,
              share);
  }
}

/**
 * Host entry point for the backward pass of RoI-aware 3D pooling.
 *
 * @param pts_idx_of_voxels (N, out_x, out_y, out_z, max_pts_each_voxel)
 * @param argmax      (N, out_x, out_y, out_z, C)
 * @param grad_out    (N, out_x, out_y, out_z, C)
 * @param grad_in     (npoints, C), accumulated into by the kernels
 * @param pool_method 0: max_pool, 1: avg_pool; other values are reported on
 *                    stderr and nothing is launched
 *
 * Kernels are launched on the default stream. A launch error is reported on
 * stderr via cudaPeekAtLastError, leaving it pending for the caller's checks.
 */
void roiaware_pool3d_backward_launcher(int boxes_num, int out_x, int out_y,
                                       int out_z, int channels,
                                       int max_pts_each_voxel,
                                       const int *pts_idx_of_voxels,
                                       const int *argmax, const float *grad_out,
                                       float *grad_in, int pool_method) {
  const int num_voxels = out_x * out_y * out_z;
  // Nothing to back-propagate, and a zero-sized grid would be an invalid
  // launch configuration.
  if (boxes_num <= 0 || channels <= 0 || num_voxels <= 0) return;

  dim3 blocks(DIVUP(num_voxels, THREADS_PER_BLOCK), channels, boxes_num);
  dim3 threads(THREADS_PER_BLOCK);
  if (pool_method == 0) {
    roiaware_maxpool3d_backward<<<blocks, threads>>>(
        boxes_num, channels, out_x, out_y, out_z, argmax, grad_out, grad_in);
  } else if (pool_method == 1) {
    roiaware_avgpool3d_backward<<<blocks, threads>>>(
        boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel,
        pts_idx_of_voxels, grad_out, grad_in);
  } else {
    fprintf(stderr,
            "roiaware_pool3d_backward_launcher: unsupported pool_method %d\n",
            pool_method);
    return;
  }

  const cudaError_t err = cudaPeekAtLastError();
  if (err != cudaSuccess) {
    fprintf(stderr,
            "roiaware_pool3d_backward_launcher: kernel launch failed: %s\n",
            cudaGetErrorString(err));
  }
}