Unverified Commit 2f88c124 authored by Wenhao Wu's avatar Wenhao Wu Committed by GitHub
Browse files

[Enhance] Replace mmdet3d ops with mmcv ops (#1240)

* import some ops from mmcv instead of mmdet3d

* use mmcv ops in primitive_head.py

* use mmcv ops in PAConv

* remove ops in mmdet3d & fix some bugs

* remove spconv & fix some bugs

* fix voxelization unittest

* remove spconv in ops/__init__.py

* refine ops/__init__.py

* recover sparse_block in ops/__init__

* fix parta2_bbox_head unittest

* remove remaining ops

* recover ops/__init__.py for bc breaking

* add source of ops from mmcv

* recover the unittest for voxelization

* remove unittest
parent 41d77dad
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.ops import GroupAll
from mmcv.ops import PointsSampler as Points_Sampler
from mmcv.ops import QueryAndGroup, gather_points
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.ops import (GroupAll, PAConv, Points_Sampler, QueryAndGroup, from mmdet3d.ops import PAConv
gather_points)
from .builder import SA_MODULES from .builder import SA_MODULES
......
# Copyright (c) OpenMMLab. All rights reserved.
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from .roiaware_pool3d import RoIAwarePool3d
__all__ = [
'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu',
'points_in_boxes_all'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from . import roiaware_pool3d_ext
def points_in_boxes_part(points, boxes):
"""Find the box in which each point is (CUDA).
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in
LiDAR/DEPTH coordinate, (x, y, z) is the bottom center
Returns:
box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
"""
assert points.shape[0] == boxes.shape[0], \
f'Points and boxes should have the same batch size, ' \
f'got {points.shape[0]} and {boxes.shape[0]}'
assert boxes.shape[2] == 7, \
f'boxes dimension should be 7, ' \
f'got unexpected shape {boxes.shape[2]}'
assert points.shape[2] == 3, \
f'points dimension should be 3, ' \
f'got unexpected shape {points.shape[2]}'
batch_size, num_points, _ = points.shape
box_idxs_of_pts = points.new_zeros((batch_size, num_points),
dtype=torch.int).fill_(-1)
# If manually put the tensor 'points' or 'boxes' on a device
# which is not the current device, some temporary variables
# will be created on the current device in the cuda op,
# and the output will be incorrect.
# Therefore, we force the current device to be the same
# as the device of the tensors if it was not.
# Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
# for the incorrect output before the fix.
points_device = points.get_device()
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
roiaware_pool3d_ext.points_in_boxes_part(boxes.contiguous(),
points.contiguous(),
box_idxs_of_pts)
return box_idxs_of_pts
def points_in_boxes_cpu(points, boxes):
"""Find all boxes in which each point is (CPU). The CPU version of
:meth:`points_in_boxes_all`.
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in
LiDAR/DEPTH coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
(x, y, z) is the bottom center.
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
"""
assert points.shape[0] == boxes.shape[0], \
f'Points and boxes should have the same batch size, ' \
f'got {points.shape[0]} and {boxes.shape[0]}'
assert boxes.shape[2] == 7, \
f'boxes dimension should be 7, ' \
f'got unexpected shape {boxes.shape[2]}'
assert points.shape[2] == 3, \
f'points dimension should be 3, ' \
f'got unexpected shape {points.shape[2]}'
batch_size, num_points, _ = points.shape
num_boxes = boxes.shape[1]
point_indices = points.new_zeros((batch_size, num_boxes, num_points),
dtype=torch.int)
for b in range(batch_size):
roiaware_pool3d_ext.points_in_boxes_cpu(boxes[b].float().contiguous(),
points[b].float().contiguous(),
point_indices[b])
point_indices = point_indices.transpose(1, 2)
return point_indices
def points_in_boxes_all(points, boxes):
"""Find all boxes in which each point is (CUDA).
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
(x, y, z) is the bottom center.
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
"""
assert boxes.shape[0] == points.shape[0], \
f'Points and boxes should have the same batch size, ' \
f'got {boxes.shape[0]} and {boxes.shape[0]}'
assert boxes.shape[2] == 7, \
f'boxes dimension should be 7, ' \
f'got unexpected shape {boxes.shape[2]}'
assert points.shape[2] == 3, \
f'points dimension should be 3, ' \
f'got unexpected shape {points.shape[2]}'
batch_size, num_points, _ = points.shape
num_boxes = boxes.shape[1]
box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
dtype=torch.int).fill_(0)
# Same reason as line 25-32
points_device = points.get_device()
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
roiaware_pool3d_ext.points_in_boxes_all(boxes.contiguous(),
points.contiguous(),
box_idxs_of_pts)
return box_idxs_of_pts
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from torch import nn as nn
from torch.autograd import Function
from . import roiaware_pool3d_ext
class RoIAwarePool3d(nn.Module):
def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
super().__init__()
"""RoIAwarePool3d module
Args:
out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m
mode (str): 'max' or 'avg'
"""
self.out_size = out_size
self.max_pts_per_voxel = max_pts_per_voxel
assert mode in ['max', 'avg']
pool_method_map = {'max': 0, 'avg': 1}
self.mode = pool_method_map[mode]
def forward(self, rois, pts, pts_feature):
"""RoIAwarePool3d module forward.
Args:
rois (torch.Tensor): [N, 7],in LiDAR coordinate,
(x, y, z) is the bottom center of rois
pts (torch.Tensor): [npoints, 3]
pts_feature (torch.Tensor): [npoints, C]
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
self.out_size,
self.max_pts_per_voxel, self.mode)
class RoIAwarePool3dFunction(Function):
@staticmethod
def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
mode):
"""RoIAwarePool3d function forward.
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center of rois
pts (torch.Tensor): [npoints, 3]
pts_feature (torch.Tensor): [npoints, C]
out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m
mode (int): 0 (max pool) or 1 (average pool)
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
if isinstance(out_size, int):
out_x = out_y = out_z = out_size
else:
assert len(out_size) == 3
assert mmcv.is_tuple_of(out_size, int)
out_x, out_y, out_z = out_size
num_rois = rois.shape[0]
num_channels = pts_feature.shape[-1]
num_pts = pts.shape[0]
pooled_features = pts_feature.new_zeros(
(num_rois, out_x, out_y, out_z, num_channels))
argmax = pts_feature.new_zeros(
(num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
pts_idx_of_voxels = pts_feature.new_zeros(
(num_rois, out_x, out_y, out_z, max_pts_per_voxel),
dtype=torch.int)
roiaware_pool3d_ext.forward(rois, pts, pts_feature, argmax,
pts_idx_of_voxels, pooled_features, mode)
ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
num_pts, num_channels)
return pooled_features
@staticmethod
def backward(ctx, grad_out):
"""RoIAwarePool3d function forward.
Args:
grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]
Returns:
grad_in (torch.Tensor): [npoints, C]
"""
ret = ctx.roiaware_pool3d_for_backward
pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret
grad_in = grad_out.new_zeros((num_pts, num_channels))
roiaware_pool3d_ext.backward(pts_idx_of_voxels, argmax,
grad_out.contiguous(), grad_in, mode)
return None, None, grad_in, None, None, None
if __name__ == '__main__':
pass
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
// #define DEBUG
inline void lidar_to_local_coords_cpu(float shift_x, float shift_y, float rz,
float &local_x, float &local_y) {
float cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
inline int check_pt_in_box3d_cpu(const float *pt, const float *box3d,
float &local_x, float &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, cz in the
// bottom center
float x = pt[0], y = pt[1], z = pt[2];
float cx = box3d[0], cy = box3d[1], cz = box3d[2];
float x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size / 2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords_cpu(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}
int points_in_boxes_cpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
at::Tensor pts_indices_tensor) {
// params boxes: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate, z is the
// bottom center, each box DO NOT overlaps params pts: (npoints, 3) [x, y, z]
// in LiDAR coordinate params pts_indices: (N, npoints)
CHECK_CONTIGUOUS(boxes_tensor);
CHECK_CONTIGUOUS(pts_tensor);
CHECK_CONTIGUOUS(pts_indices_tensor);
int boxes_num = boxes_tensor.size(0);
int pts_num = pts_tensor.size(0);
const float *boxes = boxes_tensor.data_ptr<float>();
const float *pts = pts_tensor.data_ptr<float>();
int *pts_indices = pts_indices_tensor.data_ptr<int>();
float local_x = 0, local_y = 0;
for (int i = 0; i < boxes_num; i++) {
for (int j = 0; j < pts_num; j++) {
int cur_in_flag =
check_pt_in_box3d_cpu(pts + j * 3, boxes + i * 7, local_x, local_y);
pts_indices[i * pts_num + j] = cur_in_flag;
}
}
return 1;
}
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <torch/serialize/tensor.h>
#include <torch/types.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// #define DEBUG
__device__ inline void lidar_to_local_coords(float shift_x, float shift_y,
float rz, float &local_x,
float &local_y) {
float cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
__device__ inline int check_pt_in_box3d(const float *pt, const float *box3d,
float &local_x, float &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, cz in the
// bottom center
float x = pt[0], y = pt[1], z = pt[2];
float cx = box3d[0], cy = box3d[1], cz = box3d[2];
float x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size / 2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}
__global__ void points_in_boxes_part_kernel(int batch_size, int boxes_num,
int pts_num, const float *boxes,
const float *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;
boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num + pt_idx;
float local_x = 0, local_y = 0;
int cur_in_flag = 0;
for (int k = 0; k < boxes_num; k++) {
cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[0] = k;
break;
}
}
}
__global__ void points_in_boxes_all_kernel(int batch_size, int boxes_num,
int pts_num, const float *boxes,
const float *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;
boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;
float local_x = 0, local_y = 0;
int cur_in_flag = 0;
for (int k = 0; k < boxes_num; k++) {
cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[k] = 1;
}
cur_in_flag = 0;
}
}
void points_in_boxes_part_launcher(int batch_size, int boxes_num, int pts_num,
const float *boxes, const float *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
cudaError_t err;
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
points_in_boxes_part_kernel<<<blocks, threads>>>(batch_size, boxes_num, pts_num,
boxes, pts, box_idx_of_points);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
#ifdef DEBUG
cudaDeviceSynchronize(); // for using printf in kernel function
#endif
}
void points_in_boxes_all_launcher(int batch_size, int boxes_num, int pts_num,
const float *boxes, const float *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate, z is
// the bottom center, each box params pts: (B, npoints, 3) [x, y, z] in
// LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1
cudaError_t err;
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
points_in_boxes_all_kernel<<<blocks, threads>>>(
batch_size, boxes_num, pts_num, boxes, pts, box_idx_of_points);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
#ifdef DEBUG
cudaDeviceSynchronize(); // for using printf in kernel function
#endif
}
int points_in_boxes_part(at::Tensor boxes_tensor, at::Tensor pts_tensor,
at::Tensor box_idx_of_points_tensor) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
CHECK_INPUT(boxes_tensor);
CHECK_INPUT(pts_tensor);
CHECK_INPUT(box_idx_of_points_tensor);
int batch_size = boxes_tensor.size(0);
int boxes_num = boxes_tensor.size(1);
int pts_num = pts_tensor.size(1);
const float *boxes = boxes_tensor.data_ptr<float>();
const float *pts = pts_tensor.data_ptr<float>();
int *box_idx_of_points = box_idx_of_points_tensor.data_ptr<int>();
points_in_boxes_part_launcher(batch_size, boxes_num, pts_num, boxes, pts,
box_idx_of_points);
return 1;
}
int points_in_boxes_all(at::Tensor boxes_tensor, at::Tensor pts_tensor,
at::Tensor box_idx_of_points_tensor) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate, z is
// the bottom center. params pts: (B, npoints, 3) [x, y, z] in LiDAR
// coordinate params boxes_idx_of_points: (B, npoints), default -1
CHECK_INPUT(boxes_tensor);
CHECK_INPUT(pts_tensor);
CHECK_INPUT(box_idx_of_points_tensor);
int batch_size = boxes_tensor.size(0);
int boxes_num = boxes_tensor.size(1);
int pts_num = pts_tensor.size(1);
const float *boxes = boxes_tensor.data_ptr<float>();
const float *pts = pts_tensor.data_ptr<float>();
int *box_idx_of_points = box_idx_of_points_tensor.data_ptr<int>();
points_in_boxes_all_launcher(batch_size, boxes_num, pts_num, boxes, pts,
box_idx_of_points);
return 1;
}
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.
#include <assert.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const float *rois, const float *pts,
const float *pts_feature, int *argmax,
int *pts_idx_of_voxels, float *pooled_features,
int pool_method);
void roiaware_pool3d_backward_launcher(int boxes_num, int out_x, int out_y,
int out_z, int channels,
int max_pts_each_voxel,
const int *pts_idx_of_voxels,
const int *argmax, const float *grad_out,
float *grad_in, int pool_method);
int roiaware_pool3d_gpu(at::Tensor rois, at::Tensor pts, at::Tensor pts_feature,
at::Tensor argmax, at::Tensor pts_idx_of_voxels,
at::Tensor pooled_features, int pool_method);
int roiaware_pool3d_gpu_backward(at::Tensor pts_idx_of_voxels,
at::Tensor argmax, at::Tensor grad_out,
at::Tensor grad_in, int pool_method);
int points_in_boxes_cpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
at::Tensor pts_indices_tensor);
int points_in_boxes_part(at::Tensor boxes_tensor, at::Tensor pts_tensor,
at::Tensor box_idx_of_points_tensor);
int points_in_boxes_all(at::Tensor boxes_tensor, at::Tensor pts_tensor,
at::Tensor box_idx_of_points_tensor);
int roiaware_pool3d_gpu(at::Tensor rois, at::Tensor pts, at::Tensor pts_feature,
at::Tensor argmax, at::Tensor pts_idx_of_voxels,
at::Tensor pooled_features, int pool_method) {
// params rois: (N, 7) [x, y, z, x_size, y_size, z_size, ry] in LiDAR coordinate
// params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
// params pts_feature: (npoints, C)
// params argmax: (N, out_x, out_y, out_z, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params pooled_features: (N, out_x, out_y, out_z, C)
// params pool_method: 0: max_pool 1: avg_pool
CHECK_INPUT(rois);
CHECK_INPUT(pts);
CHECK_INPUT(pts_feature);
CHECK_INPUT(argmax);
CHECK_INPUT(pts_idx_of_voxels);
CHECK_INPUT(pooled_features);
int boxes_num = rois.size(0);
int pts_num = pts.size(0);
int channels = pts_feature.size(1);
int max_pts_each_voxel = pts_idx_of_voxels.size(4); // index 0 is the counter
int out_x = pts_idx_of_voxels.size(1);
int out_y = pts_idx_of_voxels.size(2);
int out_z = pts_idx_of_voxels.size(3);
assert((out_x < 256) && (out_y < 256) &&
(out_z < 256)); // we encode index with 8bit
const float *rois_data = rois.data_ptr<float>();
const float *pts_data = pts.data_ptr<float>();
const float *pts_feature_data = pts_feature.data_ptr<float>();
int *argmax_data = argmax.data_ptr<int>();
int *pts_idx_of_voxels_data = pts_idx_of_voxels.data_ptr<int>();
float *pooled_features_data = pooled_features.data_ptr<float>();
roiaware_pool3d_launcher(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
rois_data, pts_data, pts_feature_data, argmax_data,
pts_idx_of_voxels_data, pooled_features_data, pool_method);
return 1;
}
int roiaware_pool3d_gpu_backward(at::Tensor pts_idx_of_voxels,
at::Tensor argmax, at::Tensor grad_out,
at::Tensor grad_in, int pool_method) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
// params pool_method: 0: max_pool 1: avg_pool
CHECK_INPUT(pts_idx_of_voxels);
CHECK_INPUT(argmax);
CHECK_INPUT(grad_out);
CHECK_INPUT(grad_in);
int boxes_num = pts_idx_of_voxels.size(0);
int out_x = pts_idx_of_voxels.size(1);
int out_y = pts_idx_of_voxels.size(2);
int out_z = pts_idx_of_voxels.size(3);
int max_pts_each_voxel = pts_idx_of_voxels.size(4); // index 0 is the counter
int channels = grad_out.size(4);
const int *pts_idx_of_voxels_data = pts_idx_of_voxels.data_ptr<int>();
const int *argmax_data = argmax.data_ptr<int>();
const float *grad_out_data = grad_out.data_ptr<float>();
float *grad_in_data = grad_in.data_ptr<float>();
roiaware_pool3d_backward_launcher(boxes_num, out_x, out_y, out_z, channels,
max_pts_each_voxel, pts_idx_of_voxels_data,
argmax_data, grad_out_data, grad_in_data,
pool_method);
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &roiaware_pool3d_gpu, "roiaware pool3d forward (CUDA)");
m.def("backward", &roiaware_pool3d_gpu_backward,
"roiaware pool3d backward (CUDA)");
m.def("points_in_boxes_part", &points_in_boxes_part,
"points_in_boxes_part forward (CUDA)");
m.def("points_in_boxes_all", &points_in_boxes_all,
"points_in_boxes_all forward (CUDA)");
m.def("points_in_boxes_cpu", &points_in_boxes_cpu,
"points_in_boxes_cpu forward (CPU)");
}
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <torch/serialize/tensor.h>
#include <torch/types.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
// #define DEBUG
__device__ inline void lidar_to_local_coords(float shift_x, float shift_y,
float rz, float &local_x,
float &local_y) {
float cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
__device__ inline int check_pt_in_box3d(const float *pt, const float *box3d,
float &local_x, float &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, cz in the
// bottom center
float x = pt[0], y = pt[1], z = pt[2];
float cx = box3d[0], cy = box3d[1], cz = box3d[2];
float x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size / 2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}
__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
int out_x, int out_y, int out_z,
const float *rois, const float *pts,
int *pts_mask) {
// params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate
// params pts: (npoints, 3) [x, y, z]
// params pts_mask: (N, npoints): -1 means point does not in this box,
// otherwise: encode (x_idxs, y_idxs, z_idxs) by binary bit
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
int box_idx = blockIdx.y;
if (pt_idx >= pts_num || box_idx >= boxes_num) return;
pts += pt_idx * 3;
rois += box_idx * 7;
pts_mask += box_idx * pts_num + pt_idx;
float local_x = 0, local_y = 0;
int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y);
pts_mask[0] = -1;
if (cur_in_flag > 0) {
float local_z = pts[2] - rois[2];
float x_size = rois[3], y_size = rois[4], z_size = rois[5];
float x_res = x_size / out_x;
float y_res = y_size / out_y;
float z_res = z_size / out_z;
unsigned int x_idx = int((local_x + x_size / 2) / x_res);
unsigned int y_idx = int((local_y + y_size / 2) / y_res);
unsigned int z_idx = int(local_z / z_res);
x_idx = min(max(x_idx, 0), out_x - 1);
y_idx = min(max(y_idx, 0), out_y - 1);
z_idx = min(max(z_idx, 0), out_z - 1);
unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx;
#ifdef DEBUG
printf(
"mask: pts_%d(%.3f, %.3f, %.3f), local(%.3f, %.3f, %.3f), idx(%d, %d, "
"%d), res(%.3f, %.3f, %.3f), idx_encoding=%x\n",
pt_idx, pts[0], pts[1], pts[2], local_x, local_y, local_z, x_idx, y_idx,
z_idx, x_res, y_res, z_res, idx_encoding);
#endif
pts_mask[0] = idx_encoding;
}
}
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
int max_pts_each_voxel, int out_x,
int out_y, int out_z,
const int *pts_mask,
int *pts_idx_of_voxels) {
// params pts_mask: (N, npoints) 0 or 1
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
int box_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (box_idx >= boxes_num) return;
int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel;
for (int k = 0; k < pts_num; k++) {
if (pts_mask[box_idx * pts_num + k] != -1) {
unsigned int idx_encoding = pts_mask[box_idx * pts_num + k];
unsigned int x_idx = (idx_encoding >> 16) & 0xFF;
unsigned int y_idx = (idx_encoding >> 8) & 0xFF;
unsigned int z_idx = idx_encoding & 0xFF;
unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel +
y_idx * out_z * max_pts_each_voxel +
z_idx * max_pts_each_voxel;
unsigned int cnt = pts_idx_of_voxels[base_offset];
if (cnt < max_num_pts) {
pts_idx_of_voxels[base_offset + cnt + 1] = k;
pts_idx_of_voxels[base_offset]++;
}
#ifdef DEBUG
printf("collect: pts_%d, idx(%d, %d, %d), idx_encoding=%x\n", k, x_idx,
y_idx, z_idx, idx_encoding);
#endif
}
}
}
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const float *pts_feature,
const int *pts_idx_of_voxels,
float *pooled_features, int *argmax) {
// params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
// index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C)
// params argmax: (N, out_x, out_y, out_z, C)
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
#ifdef DEBUG
printf("src pts_idx_of_voxels: (%p, ), argmax: %p\n", pts_idx_of_voxels,
argmax);
#endif
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
pooled_features += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
argmax += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
int argmax_idx = -1;
float max_val = -1e50;
int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++) {
if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > max_val) {
max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
argmax_idx = pts_idx_of_voxels[k];
}
}
if (argmax_idx != -1) {
pooled_features[0] = max_val;
}
argmax[0] = argmax_idx;
#ifdef DEBUG
printf(
"channel_%d idx(%d, %d, %d), argmax_idx=(%d, %.3f), total=%d, after "
"pts_idx: %p, argmax: (%p, %d)\n",
channel_idx, x_idx, y_idx, z_idx, argmax_idx, max_val, total_pts,
pts_idx_of_voxels, argmax, argmax_idx);
#endif
}
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const float *pts_feature,
const int *pts_idx_of_voxels,
float *pooled_features) {
// params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
// index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C)
// params argmax: (N, out_x, out_y, out_z, C)
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
pooled_features += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
float sum_val = 0;
int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++) {
sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
}
if (total_pts > 0) {
pooled_features[0] = sum_val / total_pts;
}
}
void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const float *rois, const float *pts,
const float *pts_feature, int *argmax,
int *pts_idx_of_voxels, float *pooled_features,
int pool_method) {
// params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR coordinate
// params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
// params pts_feature: (npoints, C)
// params argmax: (N, out_x, out_y, out_z, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params pooled_features: (N, out_x, out_y, out_z, C)
// params pool_method: 0: max_pool 1: avg_pool
int *pts_mask = NULL;
cudaMalloc(&pts_mask, boxes_num * pts_num * sizeof(int)); // (N, M)
cudaMemset(pts_mask, -1, boxes_num * pts_num * sizeof(int));
dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num);
dim3 threads(THREADS_PER_BLOCK);
generate_pts_mask_for_box3d<<<blocks_mask, threads>>>(
boxes_num, pts_num, out_x, out_y, out_z, rois, pts, pts_mask);
// TODO: Merge the collect and pool functions, SS
dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK));
collect_inside_pts_for_box3d<<<blocks_collect, threads>>>(
boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z, pts_mask,
pts_idx_of_voxels);
dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
boxes_num);
if (pool_method == 0) {
roiaware_maxpool3d<<<blocks_pool, threads>>>(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
pts_feature, pts_idx_of_voxels, pooled_features, argmax);
} else if (pool_method == 1) {
roiaware_avgpool3d<<<blocks_pool, threads>>>(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
pts_feature, pts_idx_of_voxels, pooled_features);
}
cudaFree(pts_mask);
#ifdef DEBUG
cudaDeviceSynchronize(); // for using printf in kernel function
#endif
}
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
int out_x, int out_y, int out_z,
const int *argmax,
const float *grad_out,
float *grad_in) {
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
argmax += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
grad_out += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
if (argmax[0] == -1) return;
atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1);
}
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
int out_x, int out_y, int out_z,
int max_pts_each_voxel,
const int *pts_idx_of_voxels,
const float *grad_out,
float *grad_in) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
grad_out += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
int total_pts = pts_idx_of_voxels[0];
float cur_grad = 1 / fmaxf(float(total_pts), 1.0);
for (int k = 1; k <= total_pts; k++) {
atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx,
grad_out[0] * cur_grad);
}
}
void roiaware_pool3d_backward_launcher(int boxes_num, int out_x, int out_y,
int out_z, int channels,
int max_pts_each_voxel,
const int *pts_idx_of_voxels,
const int *argmax, const float *grad_out,
float *grad_in, int pool_method) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
// params pool_method: 0: max_pool, 1: avg_pool
dim3 blocks(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
boxes_num);
dim3 threads(THREADS_PER_BLOCK);
if (pool_method == 0) {
roiaware_maxpool3d_backward<<<blocks, threads>>>(
boxes_num, channels, out_x, out_y, out_z, argmax, grad_out, grad_in);
} else if (pool_method == 1) {
roiaware_avgpool3d_backward<<<blocks, threads>>>(
boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel,
pts_idx_of_voxels, grad_out, grad_in);
}
}
# Copyright (c) OpenMMLab. All rights reserved.
from .roipoint_pool3d import RoIPointPool3d
__all__ = ['RoIPointPool3d']
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn as nn
from torch.autograd import Function
from . import roipoint_pool3d_ext
class RoIPointPool3d(nn.Module):
def __init__(self, num_sampled_points=512):
super().__init__()
"""
Args:
num_sampled_points (int): Number of samples in each roi
"""
self.num_sampled_points = num_sampled_points
def forward(self, points, point_features, boxes3d):
"""
Args:
points (torch.Tensor): Input points whose shape is BxNx3
point_features: (B, N, C)
boxes3d: (B, M, 7), [x, y, z, dx, dy, dz, heading]
Returns:
torch.Tensor: (B, M, 512, 3 + C) pooled_features
torch.Tensor: (B, M) pooled_empty_flag
"""
return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
self.num_sampled_points)
class RoIPointPool3dFunction(Function):
@staticmethod
def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
"""
Args:
points (torch.Tensor): Input points whose shape is (B, N, 3)
point_features (torch.Tensor): Input points features shape is \
(B, N, C)
boxes3d (torch.Tensor): Input bounding boxes whose shape is \
(B, M, 7)
num_sampled_points (int): the num of sampled points
Returns:
torch.Tensor: (B, M, 512, 3 + C) pooled_features
torch.Tensor: (B, M) pooled_empty_flag
"""
assert points.shape.__len__() == 3 and points.shape[2] == 3
batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[
1], point_features.shape[2]
pooled_boxes3d = boxes3d.view(batch_size, -1, 7)
pooled_features = point_features.new_zeros(
(batch_size, boxes_num, num_sampled_points, 3 + feature_len))
pooled_empty_flag = point_features.new_zeros(
(batch_size, boxes_num)).int()
roipoint_pool3d_ext.forward(points.contiguous(),
pooled_boxes3d.contiguous(),
point_features.contiguous(),
pooled_features, pooled_empty_flag)
return pooled_features, pooled_empty_flag
@staticmethod
def backward(ctx, grad_out):
raise NotImplementedError
if __name__ == '__main__':
pass
/*
Modified for
https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
if (!x.is_contiguous()) { \
fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
void roipool3dLauncher(int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num,
const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag);
int roipool3d_gpu(at::Tensor xyz, at::Tensor boxes3d, at::Tensor pts_feature, at::Tensor pooled_features, at::Tensor pooled_empty_flag){
// params xyz: (B, N, 3)
// params boxes3d: (B, M, 7)
// params pts_feature: (B, N, C)
// params pooled_features: (B, M, 512, 3+C)
// params pooled_empty_flag: (B, M)
CHECK_INPUT(xyz);
CHECK_INPUT(boxes3d);
CHECK_INPUT(pts_feature);
CHECK_INPUT(pooled_features);
CHECK_INPUT(pooled_empty_flag);
int batch_size = xyz.size(0);
int pts_num = xyz.size(1);
int boxes_num = boxes3d.size(1);
int feature_in_len = pts_feature.size(2);
int sampled_pts_num = pooled_features.size(2);
const float * xyz_data = xyz.data<float>();
const float * boxes3d_data = boxes3d.data<float>();
const float * pts_feature_data = pts_feature.data<float>();
float * pooled_features_data = pooled_features.data<float>();
int * pooled_empty_flag_data = pooled_empty_flag.data<int>();
roipool3dLauncher(batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
xyz_data, boxes3d_data, pts_feature_data, pooled_features_data, pooled_empty_flag_data);
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &roipool3d_gpu, "roipool3d forward (CUDA)");
}
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.ops import SparseModule, SparseSequential
from torch import nn from torch import nn
from mmdet3d.ops import spconv
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
class SparseBottleneck(Bottleneck, spconv.SparseModule): class SparseBottleneck(Bottleneck, SparseModule):
"""Sparse bottleneck block for PartA^2. """Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution. Bottleneck block implemented with submanifold sparse convolution.
...@@ -32,7 +32,7 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule): ...@@ -32,7 +32,7 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
conv_cfg=None, conv_cfg=None,
norm_cfg=None): norm_cfg=None):
spconv.SparseModule.__init__(self) SparseModule.__init__(self)
Bottleneck.__init__( Bottleneck.__init__(
self, self,
inplanes, inplanes,
...@@ -65,7 +65,7 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule): ...@@ -65,7 +65,7 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
return out return out
class SparseBasicBlock(BasicBlock, spconv.SparseModule): class SparseBasicBlock(BasicBlock, SparseModule):
"""Sparse basic block for PartA^2. """Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution. Sparse basic block implemented with submanifold sparse convolution.
...@@ -90,7 +90,7 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule): ...@@ -90,7 +90,7 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
downsample=None, downsample=None,
conv_cfg=None, conv_cfg=None,
norm_cfg=None): norm_cfg=None):
spconv.SparseModule.__init__(self) SparseModule.__init__(self)
BasicBlock.__init__( BasicBlock.__init__(
self, self,
inplanes, inplanes,
...@@ -182,5 +182,5 @@ def make_sparse_convmodule(in_channels, ...@@ -182,5 +182,5 @@ def make_sparse_convmodule(in_channels,
elif layer == 'act': elif layer == 'act':
layers.append(nn.ReLU(inplace=True)) layers.append(nn.ReLU(inplace=True))
layers = spconv.SparseSequential(*layers) layers = SparseSequential(*layers)
return layers return layers
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment