Commit 643b46d3 authored by zww's avatar zww
Browse files

Replace `#include <torch/extension.h>` with `#include <torch/types.h>` to avoid a CUDA bug

parent 397a9280
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include <assert.h> #include <assert.h>
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <torch/types.h>
#define THREADS_PER_BLOCK 256 #define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
......
//Modified from // Modified from
//https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu // https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
//RoI-aware point cloud feature pooling // RoI-aware point cloud feature pooling
//Written by Shaoshuai Shi // Written by Shaoshuai Shi
//All Rights Reserved 2019. // All Rights Reserved 2019.
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <assert.h> #include <assert.h>
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include <torch/serialize/tensor.h>
#include <torch/types.h>
#define THREADS_PER_BLOCK 256 #define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
// #define DEBUG // #define DEBUG
__device__ inline void lidar_to_local_coords(float shift_x, float shift_y,
__device__ inline void lidar_to_local_coords(float shift_x, float shift_y, float rz, float &local_x, float &local_y){ float rz, float &local_x,
float &local_y) {
// should rotate pi/2 + alpha to translate LiDAR to local // should rotate pi/2 + alpha to translate LiDAR to local
float rot_angle = rz + M_PI / 2; float rot_angle = rz + M_PI / 2;
float cosa = cos(rot_angle), sina = sin(rot_angle); float cosa = cos(rot_angle), sina = sin(rot_angle);
...@@ -26,10 +25,11 @@ __device__ inline void lidar_to_local_coords(float shift_x, float shift_y, float ...@@ -26,10 +25,11 @@ __device__ inline void lidar_to_local_coords(float shift_x, float shift_y, float
local_y = shift_x * sina + shift_y * cosa; local_y = shift_x * sina + shift_y * cosa;
} }
__device__ inline int check_pt_in_box3d(const float *pt, const float *box3d,
__device__ inline int check_pt_in_box3d(const float *pt, const float *box3d, float &local_x, float &local_y){ float &local_x, float &local_y) {
// param pt: (x, y, z) // param pt: (x, y, z)
// param box3d: (cx, cy, cz, w, l, h, rz) in LiDAR coordinate, cz in the bottom center // param box3d: (cx, cy, cz, w, l, h, rz) in LiDAR coordinate, cz in the
// bottom center
float x = pt[0], y = pt[1], z = pt[2]; float x = pt[0], y = pt[1], z = pt[2];
float cx = box3d[0], cy = box3d[1], cz = box3d[2]; float cx = box3d[0], cy = box3d[1], cz = box3d[2];
float w = box3d[3], l = box3d[4], h = box3d[5], rz = box3d[6]; float w = box3d[3], l = box3d[4], h = box3d[5], rz = box3d[6];
...@@ -37,16 +37,19 @@ __device__ inline int check_pt_in_box3d(const float *pt, const float *box3d, flo ...@@ -37,16 +37,19 @@ __device__ inline int check_pt_in_box3d(const float *pt, const float *box3d, flo
if (fabsf(z - cz) > h / 2.0) return 0; if (fabsf(z - cz) > h / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y); lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -l / 2.0) & (local_x < l / 2.0) & (local_y > -w / 2.0) & (local_y < w / 2.0); float in_flag = (local_x > -l / 2.0) & (local_x < l / 2.0) &
(local_y > -w / 2.0) & (local_y < w / 2.0);
return in_flag; return in_flag;
} }
__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, int out_x, int out_y, int out_z, int out_x, int out_y, int out_z,
const float *rois, const float *pts, int *pts_mask){ const float *rois, const float *pts,
int *pts_mask) {
// params rois: (N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate // params rois: (N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate
// params pts: (npoints, 3) [x, y, z] // params pts: (npoints, 3) [x, y, z]
// params pts_mask: (N, npoints): -1 means point doesnot in this box, otherwise: encode (x_idxs, y_idxs, z_idxs) by binary bit // params pts_mask: (N, npoints): -1 means point doesnot in this box,
// otherwise: encode (x_idxs, y_idxs, z_idxs) by binary bit
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
int box_idx = blockIdx.y; int box_idx = blockIdx.y;
if (pt_idx >= pts_num || box_idx >= boxes_num) return; if (pt_idx >= pts_num || box_idx >= boxes_num) return;
...@@ -59,7 +62,7 @@ __global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, int out_ ...@@ -59,7 +62,7 @@ __global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, int out_
int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y); int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y);
pts_mask[0] = -1; pts_mask[0] = -1;
if (cur_in_flag > 0){ if (cur_in_flag > 0) {
float local_z = pts[2] - rois[2]; float local_z = pts[2] - rois[2];
float w = rois[3], l = rois[4], h = rois[5]; float w = rois[3], l = rois[4], h = rois[5];
...@@ -77,17 +80,22 @@ __global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, int out_ ...@@ -77,17 +80,22 @@ __global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, int out_
unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx; unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx;
#ifdef DEBUG #ifdef DEBUG
printf("mask: pts_%d(%.3f, %.3f, %.3f), local(%.3f, %.3f, %.3f), idx(%d, %d, %d), res(%.3f, %.3f, %.3f), idx_encoding=%x\n", printf(
pt_idx, pts[0], pts[1], pts[2], local_x, local_y, local_z, x_idx, y_idx, z_idx, x_res, y_res, z_res, idx_encoding); "mask: pts_%d(%.3f, %.3f, %.3f), local(%.3f, %.3f, %.3f), idx(%d, %d, "
"%d), res(%.3f, %.3f, %.3f), idx_encoding=%x\n",
pt_idx, pts[0], pts[1], pts[2], local_x, local_y, local_z, x_idx, y_idx,
z_idx, x_res, y_res, z_res, idx_encoding);
#endif #endif
pts_mask[0] = idx_encoding; pts_mask[0] = idx_encoding;
} }
} }
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num, int max_pts_each_voxel, int max_pts_each_voxel, int out_x,
int out_x, int out_y, int out_z, const int *pts_mask, int *pts_idx_of_voxels){ int out_y, int out_z,
const int *pts_mask,
int *pts_idx_of_voxels) {
// params pts_mask: (N, npoints) 0 or 1 // params pts_mask: (N, npoints) 0 or 1
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
...@@ -97,33 +105,36 @@ __global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num, int max ...@@ -97,33 +105,36 @@ __global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num, int max
int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel; pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel;
for (int k = 0; k < pts_num; k++){ for (int k = 0; k < pts_num; k++) {
if (pts_mask[box_idx * pts_num + k] != -1){ if (pts_mask[box_idx * pts_num + k] != -1) {
unsigned int idx_encoding = pts_mask[box_idx * pts_num + k]; unsigned int idx_encoding = pts_mask[box_idx * pts_num + k];
unsigned int x_idx = (idx_encoding >> 16) & 0xFF; unsigned int x_idx = (idx_encoding >> 16) & 0xFF;
unsigned int y_idx = (idx_encoding >> 8) & 0xFF; unsigned int y_idx = (idx_encoding >> 8) & 0xFF;
unsigned int z_idx = idx_encoding & 0xFF; unsigned int z_idx = idx_encoding & 0xFF;
unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel + y_idx * out_z * max_pts_each_voxel + z_idx * max_pts_each_voxel; unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel +
y_idx * out_z * max_pts_each_voxel +
z_idx * max_pts_each_voxel;
unsigned int cnt = pts_idx_of_voxels[base_offset]; unsigned int cnt = pts_idx_of_voxels[base_offset];
if (cnt < max_num_pts){ if (cnt < max_num_pts) {
pts_idx_of_voxels[base_offset + cnt + 1] = k; pts_idx_of_voxels[base_offset + cnt + 1] = k;
pts_idx_of_voxels[base_offset]++; pts_idx_of_voxels[base_offset]++;
} }
#ifdef DEBUG #ifdef DEBUG
printf("collect: pts_%d, idx(%d, %d, %d), idx_encoding=%x\n", printf("collect: pts_%d, idx(%d, %d, %d), idx_encoding=%x\n", k, x_idx,
k, x_idx, y_idx, z_idx, idx_encoding); y_idx, z_idx, idx_encoding);
#endif #endif
} }
} }
} }
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, int max_pts_each_voxel, int out_x, int out_y,
int out_y, int out_z, const float *pts_feature, const int *pts_idx_of_voxels, float *pooled_features, int *argmax){ int out_z, const float *pts_feature,
const int *pts_idx_of_voxels,
float *pooled_features, int *argmax) {
// params pts_feature: (npoints, C) // params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), index 0 is the counter // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
// params pooled_features: (N, out_x, out_y, out_z, C) // index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C)
// params argmax: (N, out_x, out_y, out_z, C) // params argmax: (N, out_x, out_y, out_z, C)
int box_idx = blockIdx.z; int box_idx = blockIdx.z;
...@@ -133,46 +144,57 @@ __global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels, int ...@@ -133,46 +144,57 @@ __global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels, int
int x_idx = voxel_idx_flat / (out_y * out_z); int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z; int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels|| x_idx >= out_x || y_idx >= out_y || z_idx >= out_z) return; if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
#ifdef DEBUG #ifdef DEBUG
printf("src pts_idx_of_voxels: (%p, ), argmax: %p\n", pts_idx_of_voxels, argmax); printf("src pts_idx_of_voxels: (%p, ), argmax: %p\n", pts_idx_of_voxels,
argmax);
#endif #endif
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + offset_base * max_pts_each_voxel; pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
pooled_features += box_idx * out_x * out_y * out_z * channels + offset_base * channels + channel_idx; offset_base * max_pts_each_voxel;
argmax += box_idx * out_x * out_y * out_z * channels + offset_base * channels + channel_idx; pooled_features += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
argmax += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
int argmax_idx = -1; int argmax_idx = -1;
float max_val = -1e50; float max_val = -1e50;
int total_pts = pts_idx_of_voxels[0]; int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++){ for (int k = 1; k <= total_pts; k++) {
if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > max_val){ if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > max_val) {
max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx]; max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
argmax_idx = pts_idx_of_voxels[k]; argmax_idx = pts_idx_of_voxels[k];
} }
} }
if (argmax_idx != -1){ if (argmax_idx != -1) {
pooled_features[0] = max_val; pooled_features[0] = max_val;
} }
argmax[0] = argmax_idx; argmax[0] = argmax_idx;
#ifdef DEBUG #ifdef DEBUG
printf("channel_%d idx(%d, %d, %d), argmax_idx=(%d, %.3f), total=%d, after pts_idx: %p, argmax: (%p, %d)\n", printf(
channel_idx, x_idx, y_idx, z_idx, argmax_idx, max_val, total_pts, pts_idx_of_voxels, argmax, argmax_idx); "channel_%d idx(%d, %d, %d), argmax_idx=(%d, %.3f), total=%d, after "
"pts_idx: %p, argmax: (%p, %d)\n",
channel_idx, x_idx, y_idx, z_idx, argmax_idx, max_val, total_pts,
pts_idx_of_voxels, argmax, argmax_idx);
#endif #endif
} }
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, int max_pts_each_voxel, int out_x, int out_y,
int out_y, int out_z, const float *pts_feature, const int *pts_idx_of_voxels, float *pooled_features){ int out_z, const float *pts_feature,
const int *pts_idx_of_voxels,
float *pooled_features) {
// params pts_feature: (npoints, C) // params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), index 0 is the counter // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
// params pooled_features: (N, out_x, out_y, out_z, C) // index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C)
// params argmax: (N, out_x, out_y, out_z, C) // params argmax: (N, out_x, out_y, out_z, C)
int box_idx = blockIdx.z; int box_idx = blockIdx.z;
...@@ -182,28 +204,34 @@ __global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels, int ...@@ -182,28 +204,34 @@ __global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels, int
int x_idx = voxel_idx_flat / (out_y * out_z); int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z; int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels|| x_idx >= out_x || y_idx >= out_y || z_idx >= out_z) return; if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + offset_base * max_pts_each_voxel; pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
pooled_features += box_idx * out_x * out_y * out_z * channels + offset_base * channels + channel_idx; offset_base * max_pts_each_voxel;
pooled_features += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
float sum_val = 0; float sum_val = 0;
int total_pts = pts_idx_of_voxels[0]; int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++){ for (int k = 1; k <= total_pts; k++) {
sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx]; sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
} }
if (total_pts > 0){ if (total_pts > 0) {
pooled_features[0] = sum_val / total_pts; pooled_features[0] = sum_val / total_pts;
} }
} }
void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, int out_y, int out_z, int out_z, const float *rois, const float *pts,
const float *rois, const float *pts, const float *pts_feature, int *argmax, int *pts_idx_of_voxels, float *pooled_features, int pool_method){ const float *pts_feature, int *argmax,
int *pts_idx_of_voxels, float *pooled_features,
int pool_method) {
// params rois: (N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate // params rois: (N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate
// params pts: (npoints, 3) [x, y, z] in LiDAR coordinate // params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
// params pts_feature: (npoints, C) // params pts_feature: (npoints, C)
...@@ -218,25 +246,28 @@ void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels, int max_ ...@@ -218,25 +246,28 @@ void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels, int max_
dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num); dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num);
dim3 threads(THREADS_PER_BLOCK); dim3 threads(THREADS_PER_BLOCK);
generate_pts_mask_for_box3d<<<blocks_mask, threads>>>(boxes_num, pts_num, out_x, out_y, out_z, rois, pts, pts_mask); generate_pts_mask_for_box3d<<<blocks_mask, threads>>>(
boxes_num, pts_num, out_x, out_y, out_z, rois, pts, pts_mask);
// TODO: Merge the collect and pool functions, SS // TODO: Merge the collect and pool functions, SS
dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK)); dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK));
collect_inside_pts_for_box3d<<<blocks_collect, threads>>>(boxes_num, pts_num, max_pts_each_voxel, collect_inside_pts_for_box3d<<<blocks_collect, threads>>>(
out_x, out_y, out_z, pts_mask, pts_idx_of_voxels); boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z, pts_mask,
pts_idx_of_voxels);
dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels, boxes_num);
if (pool_method == 0){ dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
roiaware_maxpool3d<<<blocks_pool, threads>>>(boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z, boxes_num);
if (pool_method == 0) {
roiaware_maxpool3d<<<blocks_pool, threads>>>(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
pts_feature, pts_idx_of_voxels, pooled_features, argmax); pts_feature, pts_idx_of_voxels, pooled_features, argmax);
} } else if (pool_method == 1) {
else if (pool_method == 1){ roiaware_avgpool3d<<<blocks_pool, threads>>>(
roiaware_avgpool3d<<<blocks_pool, threads>>>(boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z, boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
pts_feature, pts_idx_of_voxels, pooled_features); pts_feature, pts_idx_of_voxels, pooled_features);
} }
cudaFree(pts_mask); cudaFree(pts_mask);
#ifdef DEBUG #ifdef DEBUG
...@@ -244,9 +275,11 @@ void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels, int max_ ...@@ -244,9 +275,11 @@ void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels, int max_
#endif #endif
} }
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels, int out_x, int out_y, int out_z, int out_x, int out_y, int out_z,
const int *argmax, const float *grad_out, float *grad_in){ const int *argmax,
const float *grad_out,
float *grad_in) {
// params argmax: (N, out_x, out_y, out_z, C) // params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C) // params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value // params grad_in: (npoints, C), return value
...@@ -258,20 +291,27 @@ __global__ void roiaware_maxpool3d_backward(int boxes_num, int channels, int out ...@@ -258,20 +291,27 @@ __global__ void roiaware_maxpool3d_backward(int boxes_num, int channels, int out
int x_idx = voxel_idx_flat / (out_y * out_z); int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z; int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels|| x_idx >= out_x || y_idx >= out_y || z_idx >= out_z) return; if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
argmax += box_idx * out_x * out_y * out_z * channels + offset_base * channels + channel_idx; argmax += box_idx * out_x * out_y * out_z * channels +
grad_out += box_idx * out_x * out_y * out_z * channels + offset_base * channels + channel_idx; offset_base * channels + channel_idx;
grad_out += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
if (argmax[0] == -1) return; if (argmax[0] == -1) return;
atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1); atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1);
} }
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels, int out_x, int out_y, int out_z, int out_x, int out_y, int out_z,
int max_pts_each_voxel, const int *pts_idx_of_voxels, const float *grad_out, float *grad_in){ int max_pts_each_voxel,
const int *pts_idx_of_voxels,
const float *grad_out,
float *grad_in) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params grad_out: (N, out_x, out_y, out_z, C) // params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value // params grad_in: (npoints, C), return value
...@@ -283,41 +323,45 @@ __global__ void roiaware_avgpool3d_backward(int boxes_num, int channels, int out ...@@ -283,41 +323,45 @@ __global__ void roiaware_avgpool3d_backward(int boxes_num, int channels, int out
int x_idx = voxel_idx_flat / (out_y * out_z); int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z; int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels|| x_idx >= out_x || y_idx >= out_y || z_idx >= out_z) return; if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + offset_base * max_pts_each_voxel; pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
grad_out += box_idx * out_x * out_y * out_z * channels + offset_base * channels + channel_idx; offset_base * max_pts_each_voxel;
grad_out += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
int total_pts = pts_idx_of_voxels[0]; int total_pts = pts_idx_of_voxels[0];
float cur_grad = 1 / fmaxf(float(total_pts), 1.0); float cur_grad = 1 / fmaxf(float(total_pts), 1.0);
for (int k = 1; k <= total_pts; k++){ for (int k = 1; k <= total_pts; k++) {
atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx, grad_out[0] * cur_grad); atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx,
grad_out[0] * cur_grad);
} }
} }
// Host-side launcher for the backward pass of RoI-aware 3D pooling.
//
// Selects the max-pool or avg-pool backward kernel according to pool_method
// and launches it on a 3D grid:
//   grid.x = DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK)  (voxels)
//   grid.y = channels
//   grid.z = boxes_num
// with THREADS_PER_BLOCK threads per block. The launch is asynchronous; this
// function does not synchronize with the device.
//
// params boxes_num: number of boxes (N)
// params out_x, out_y, out_z: pooled voxel grid resolution per box
// params channels: number of feature channels (C)
// params max_pts_each_voxel: slots per voxel in pts_idx_of_voxels
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
// params pool_method: 0: max_pool, 1: avg_pool; any other value launches
//                     nothing
void roiaware_pool3d_backward_launcher(int boxes_num, int out_x, int out_y,
                                       int out_z, int channels,
                                       int max_pts_each_voxel,
                                       const int *pts_idx_of_voxels,
                                       const int *argmax, const float *grad_out,
                                       float *grad_in, int pool_method) {
  const int voxels_per_box = out_x * out_y * out_z;

  // With no boxes, channels or voxels there is no gradient to scatter. A grid
  // dimension of zero is an invalid launch configuration, and the resulting
  // error would stay pending for whatever CUDA call checks next, so return
  // before launching anything.
  if (boxes_num <= 0 || channels <= 0 || voxels_per_box <= 0) return;

  // NOTE: channels and boxes_num are used as gridDim.y / gridDim.z, which CUDA
  // limits to 65535; larger values fail at launch and are reported below.
  dim3 blocks(DIVUP(voxels_per_box, THREADS_PER_BLOCK), channels, boxes_num);
  dim3 threads(THREADS_PER_BLOCK);

  if (pool_method == 0) {
    roiaware_maxpool3d_backward<<<blocks, threads>>>(
        boxes_num, channels, out_x, out_y, out_z, argmax, grad_out, grad_in);
  } else if (pool_method == 1) {
    roiaware_avgpool3d_backward<<<blocks, threads>>>(
        boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel,
        pts_idx_of_voxels, grad_out, grad_in);
  } else {
    // Unknown pooling method: nothing was launched, nothing to check.
    return;
  }

  // Kernel launches do not return a status. Peek (without clearing) so a bad
  // launch is visible here while callers' own error checks still see it.
  cudaError_t launch_err = cudaPeekAtLastError();
  if (launch_err != cudaSuccess) {
    fprintf(stderr, "roiaware_pool3d_backward_launcher: CUDA error: %s\n",
            cudaGetErrorString(launch_err));
  }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment