Unverified Commit 540a2209 authored by Yan Yan's avatar Yan Yan Committed by GitHub
Browse files

Merge pull request #177 from xmyqsh/GPU_Voxelization

GPU Voxelization
parents cfaa1a3a 6ac406fa
#pragma once
#include <tensorview/kernel_utils.h>
#include <tensorview/tensorview.h>
#include <torch/script.h>
namespace spconv {
/**
 * Accumulate every point into the dense grid cell it falls in.
 *
 * For each point `ix` the kernel reads NDim integer cell coordinates from
 * `indexes` (row `ix`, NDim values per row), flattens them row-major against
 * `gridShape`, and then
 *   - stores the flat cell index in `pointIndex(ix)`,
 *   - atomically increments `numPointsPerGrid[cell]`,
 *   - atomically adds the point's features to row `cell` of `grids`
 *     (addressed as `cell * numFeatures + k`).
 *
 * Points whose coordinates lie outside `[0, gridShape[d])` on any axis are
 * skipped: nothing is accumulated and `pointIndex(ix)` is left untouched, so
 * it keeps whatever value the caller initialised it with. Without this check
 * such points would address memory outside `grids` / `numPointsPerGrid`, or
 * silently alias a different valid cell.
 *
 * @param points            [numPoints, numFeatures] point features.
 * @param indexes           per-point integer cell coordinates, NDim per point.
 * @param grids             per-cell feature sums (accumulated into, not reset).
 * @param numPointsPerGrid  per-cell point counts (accumulated into, not reset).
 * @param pointIndex        per-point flat cell index (output).
 * @param gridShape         number of cells along each of the NDim axes.
 */
template <typename Index, unsigned NDim>
__global__ void scatterPointToGridKernel(
    tv::TensorView<const float> points,
    tv::TensorView<const Index> indexes,
    tv::TensorView<float> grids,
    tv::TensorView<Index> numPointsPerGrid,
    tv::TensorView<Index> pointIndex,
    const tv::SimpleVector<Index, NDim> gridShape) {
  const int numPoints = points.dim(0);
  const int numFeatures = points.dim(1);
  const Index *shape = gridShape.data();
  for (int ix : tv::KernelLoopX<int>(numPoints)) {
    const Index *coord = indexes.data() + ix * NDim;
    // Reject points that fall outside the grid on any axis.
    bool inside = true;
#pragma unroll
    for (unsigned d = 0; d != NDim; ++d) {
      inside = inside && coord[d] >= 0 && coord[d] < shape[d];
    }
    if (!inside) {
      continue;
    }
    const Index index =
        tv::ArrayIndexRowMajor<NDim, NDim>::runPtrs(coord, shape, 0);
    pointIndex(ix) = index;
    atomicAdd(numPointsPerGrid.data() + index, Index(1));
    // Use 64-bit offsets: cell * numFeatures can exceed the 32-bit range on
    // large grids.
    const float *src = points.data() + static_cast<long long>(ix) * numFeatures;
    float *dst = grids.data() + static_cast<long long>(index) * numFeatures;
#pragma unroll
    for (int k = 0; k != numFeatures; ++k) {
      atomicAdd(dst + k, src[k]);
    }
  }
}
/**
 * Turn the accumulated grid cells into a compact list of voxels.
 *
 * For output row `ix` the kernel looks up the flat cell index
 * `pointIndexUnique(ix)`, writes the per-feature mean
 * `grids(cell, k) / numPointsPerGrid(cell)` into `voxels(ix, k)` and writes the
 * cell's NDim coordinates (inverse of the row-major flattening against
 * `gridShape`) into row `ix` of `coors`.
 *
 * The number of rows processed is `pointIndexUnique.dim(0) - 1`, clamped to
 * the capacity `voxels.dim(0)`. The final entry of `pointIndexUnique` is
 * treated as a sentinel and skipped, matching `resetGridKernel` and the host
 * code, which both use `size - 1` as the voxel count. Looping over the full
 * capacity instead would read past the end of `pointIndexUnique` and
 * dereference the sentinel as a grid cell.
 * NOTE(review): this relies on the unique list being sorted so that the
 * sentinel (the grid volume) comes last -- confirm against the caller.
 *
 * @param grids             [gridVolume, numFeatures] per-cell feature sums.
 * @param numPointsPerGrid  per-cell point counts.
 * @param pointIndexUnique  flat indices of occupied cells plus one sentinel.
 * @param voxels            [capacity, numFeatures] output mean features.
 * @param coors             output cell coordinates, NDim per row.
 * @param gridShape         number of cells along each of the NDim axes.
 */
template <typename Index, unsigned NDim>
__global__ void gatherPointFromGridKernel(
    tv::TensorView<const float> grids,
    tv::TensorView<const Index> numPointsPerGrid,
    tv::TensorView<const Index> pointIndexUnique,
    tv::TensorView<float> voxels,
    tv::TensorView<Index> coors,
    const tv::SimpleVector<Index, NDim> gridShape) {
  const int numFeatures = grids.dim(1);
  const int capacity = voxels.dim(0);
  const int numOccupied = pointIndexUnique.dim(0) - 1;
  const int numVoxels = numOccupied < capacity ? numOccupied : capacity;
  for (int ix : tv::KernelLoopX<int>(numVoxels)) {
    const Index index = pointIndexUnique(ix);
    // Read the count once per voxel instead of once per feature.
    const Index count = numPointsPerGrid(index);
#pragma unroll
    for (int k = 0; k != numFeatures; ++k) {
      voxels(ix, k) = grids(index, k) / count;
    }
    // Only the coordinates written to coors matter; the return value is unused.
    tv::rowArrayIdxInv<Index, NDim>(
        index, coors.data() + ix * NDim, gridShape.data());
  }
}
/**
 * Clear the grid cells that were touched by the last voxelization.
 *
 * For every entry of `pointIndexUnique` except the last one, zeroes the cell's
 * feature row in `grids` and its counter in `numPointsPerGrid`. Only the
 * listed cells are visited; the rest of the grid is not written.
 *
 * The counter is reset once per cell, outside the feature loop, so it is also
 * cleared when `grids` has zero feature columns and is not stored redundantly
 * once per feature.
 *
 * @param grids             [gridVolume, numFeatures] per-cell feature sums.
 * @param numPointsPerGrid  per-cell point counts.
 * @param pointIndexUnique  flat indices of the cells to clear; the final
 *                          entry is skipped.
 */
template <typename Index>
__global__ void resetGridKernel(
    tv::TensorView<float> grids,
    tv::TensorView<Index> numPointsPerGrid,
    tv::TensorView<Index> pointIndexUnique) {
  const int numVoxels = pointIndexUnique.dim(0) - 1;
  const int numFeatures = grids.dim(1);
  for (int ix : tv::KernelLoopX<int>(numVoxels)) {
    const Index index = pointIndexUnique(ix);
    numPointsPerGrid(index) = 0;
#pragma unroll
    for (int k = 0; k != numFeatures; ++k) {
      grids(index, k) = 0;
    }
  }
}
/**
 * Overwrite the per-point cell index buffer with a fill value.
 *
 * Every element of `pointIndex` except the last one is set to `gridVolume`;
 * the last element is never written by this kernel.
 *
 * @param pointIndex  per-point flat cell index buffer.
 * @param gridVolume  value stored into each reset slot.
 */
template <typename Index>
__global__ void resetPointIndexKernel(
    tv::TensorView<Index> pointIndex, const Index gridVolume) {
  // All slots but the final one are reset.
  const int resettable = pointIndex.dim(0) - 1;
  for (int slot : tv::KernelLoopX<int>(resettable)) {
    pointIndex(slot) = gridVolume;
  }
}
} // namespace spconv
// Copyright 2020 xmyqsh
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <spconv/points2voxels.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h>
#include <utility/timer.h>
namespace spconv {
/**
 * Voxelize a point cloud using caller-provided work buffers.
 *
 * The implementation scatters the points into the dense grid, computes the
 * unique values of `pointIndex`, then gathers one output row per unique cell
 * into `voxels` / `coors`. Only CUDA tensors are handled: a CPU `points`
 * tensor, or any other device type, raises an invalid-argument error.
 *
 * @param points            point features; the device of this tensor selects
 *                          the backend.
 * @param indexes           per-point integer grid coordinates.
 * @param pointIndex        work buffer receiving each point's flat cell index.
 * @param grids             work buffer of per-cell feature sums.
 * @param numPointsPerGrid  work buffer of per-cell point counts.
 * @param voxels            output voxel features.
 * @param coors             output voxel coordinates.
 * @param gridShape         number of cells along each axis.
 * @param ndim              number of spatial axes.
 * @return number of unique values found in `pointIndex`, minus one.
 *         NOTE(review): the "minus one" presumably discounts a sentinel entry
 *         kept in `pointIndex` by the caller -- confirm against the caller.
 */
int64_t
pointsToVoxel(torch::Tensor points,
torch::Tensor indexes,
torch::Tensor pointIndex,
torch::Tensor grids,
torch::Tensor numPointsPerGrid,
torch::Tensor voxels,
torch::Tensor coors,
std::vector<int64_t> gridShape,
const int64_t ndim);
} // namespace spconv
#pragma once
#include <tensorview/tensorview.h>
#include <torch/script.h>
namespace spconv {
/**
 * Launch the scatter kernel that accumulates points into the dense grid.
 *
 * The kernel is enqueued on the current CUDA stream. The integer index type
 * is taken from the dtype of `pointIndex`, and `ndim` is dispatched over the
 * values 2, 3 and 4. `points` and `grids` are accessed as float.
 *
 * @param points            point features, one row per point.
 * @param indexes           per-point integer grid coordinates.
 * @param grids             per-cell feature sums (accumulated into).
 * @param numPointsPerGrid  per-cell point counts (accumulated into).
 * @param pointIndex        receives each point's flat cell index.
 * @param gridShape         number of cells along each axis.
 * @param ndim              number of spatial axes.
 */
void scatter_point_to_grid_cuda(
torch::Tensor points,
torch::Tensor indexes,
torch::Tensor grids,
torch::Tensor numPointsPerGrid,
torch::Tensor pointIndex,
std::vector<int64_t> gridShape,
const int ndim);
/**
 * Launch the kernels that compact occupied grid cells into voxels and then
 * restore the work buffers.
 *
 * Three kernels are enqueued, in this order, on the current CUDA stream:
 * one that refills `pointIndex` with the grid volume (`grids.size(0)`), one
 * that writes `voxels` / `coors` from the cells listed in `pointIndexUnique`,
 * and one that zeroes those cells in `grids` and `numPointsPerGrid`.
 * The integer index type is taken from the dtype of `pointIndexUnique`, and
 * `ndim` is dispatched over the values 2, 3 and 4.
 *
 * @param grids             per-cell feature sums.
 * @param numPointsPerGrid  per-cell point counts.
 * @param pointIndex        per-point flat cell index buffer (refilled).
 * @param pointIndexUnique  unique values of `pointIndex`.
 * @param voxels            output voxel features.
 * @param coors             output voxel coordinates.
 * @param gridShape         number of cells along each axis.
 * @param ndim              number of spatial axes.
 */
void gather_point_from_grid_cuda(
torch::Tensor grids, torch::Tensor numPointsPerGrid,
torch::Tensor pointIndex,
torch::Tensor pointIndexUnique,
torch::Tensor voxels, torch::Tensor coors,
std::vector<int64_t> gridShape,
const int ndim);
} // namespace spconv
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import torch
from spconv import spconv_utils from spconv import spconv_utils
from spconv.spconv_utils import (non_max_suppression_cpu, from spconv.spconv_utils import (non_max_suppression_cpu,
...@@ -292,3 +293,81 @@ class VoxelGeneratorV2: ...@@ -292,3 +293,81 @@ class VoxelGeneratorV2:
@property @property
def grid_size(self): def grid_size(self):
return self._grid_size return self._grid_size
class VoxelGeneratorV3:
    """Voxel generator backed by the ``spconv::points_to_voxel`` op.

    All work buffers (dense grid, per-cell counters, per-point cell index and
    the output voxel/coordinate buffers) are allocated once in ``__init__`` and
    reused by every call.

    Args:
        voxel_size: [3] float tensor, xyz size of one voxel.
        point_cloud_range: [6] float tensor, ``xyzxyz`` min/max of the volume.
        max_points: maximum number of points a single call may receive; also
            the capacity of the output buffers.
        num_features: number of feature columns per point.
        dtype: dtype of the feature buffers.
        device: device on which all buffers are allocated.
    """

    def __init__(self,
                 voxel_size,
                 point_cloud_range,
                 max_points,
                 num_features,
                 dtype,
                 device):
        self._max_points = max_points
        self._point_cloud_range = point_cloud_range
        self._voxel_size = voxel_size
        extent = point_cloud_range[3:] - point_cloud_range[:3]
        self._grid_size = torch.round(extent / voxel_size).to(torch.int32).tolist()
        # Keep the volume as a plain Python int: it is used both as a buffer
        # size and as the sentinel fill value of the point index buffer.
        grid_volume = 1
        for cells in self._grid_size:
            grid_volume *= cells
        self._ndim = len(self._grid_size)
        self._dtype = dtype
        self._device = device
        # One extra slot so that the sentinel value is always present.
        self._point_index = torch.full([max_points + 1], grid_volume,
                                       dtype=torch.int32, device=self._device)
        self._grids = torch.zeros([grid_volume, num_features],
                                  dtype=self._dtype, device=self._device)
        self._num_points_per_grid = torch.zeros([grid_volume],
                                                dtype=torch.int32, device=self._device)
        self._voxels = torch.zeros([max_points, num_features],
                                   dtype=self._dtype, device=self._device)
        self._coors = torch.zeros([max_points, self._ndim],
                                  dtype=torch.int32, device=self._device)

    def generate(self, points):
        """Voxelize ``points`` and return ``(voxels, coors)``.

        ``points`` is converted to the generator's dtype and device first.
        ``Tensor.to`` is not in-place, so the converted tensor must be rebound.

        Raises:
            AssertionError: if ``points`` has more rows than ``max_points``.
        """
        assert points.shape[0] <= self._max_points, 'please enlarge max_points to not smaller than ' + str(points.shape[0])
        points = points.to(device=self._device, dtype=self._dtype)
        return self.points_to_voxel(points)

    def generate_multi_gpu(self, points):
        """Same as :meth:`generate`; kept as a separate entry point."""
        return self.generate(points)

    @property
    def voxel_size(self):
        return self._voxel_size

    @property
    def point_cloud_range(self):
        return self._point_cloud_range

    @property
    def grid_size(self):
        return self._grid_size

    def points_to_voxel(self, points):
        """Run the voxelization op on ``points``.

        Args:
            points: [N, C] float tensor. ``points[:, :3]`` contain xyz and
                ``points[:, 3:]`` contain other information such as
                reflectivity.

        Returns:
            ``(voxels, coors)``: ``voxels`` is ``[M, C]`` and ``coors`` is
            ``[M, 3]`` in zyx order, where ``M`` is the count returned by the
            op. ``voxels`` is a view into an internal buffer that the next
            call overwrites; clone it if it must outlive that call.
        """
        indexes = torch.floor(
            (points[:, :3] - self._point_cloud_range[:3]) / self._voxel_size
        ).to(torch.int32)
        # Drop points outside the grid: their cell coordinates would address
        # memory outside the dense grid buffers.
        upper = indexes.new_tensor(self._grid_size)
        in_range = ((indexes >= 0) & (indexes < upper)).all(dim=1)
        points = points[in_range].contiguous()
        indexes = indexes[in_range].contiguous()
        num_voxel = torch.ops.spconv.points_to_voxel(points, indexes,
                                                     self._point_index,
                                                     self._grids,
                                                     self._num_points_per_grid,
                                                     self._voxels,
                                                     self._coors,
                                                     self._grid_size,
                                                     self._ndim)
        voxels = self._voxels[:num_voxel, :]
        # xyz --> zyx (reverses the coordinate columns into a new tensor)
        coors = torch.flip(self._coors[:num_voxel, :], dims=[1])
        return voxels, coors
set(ALL_FILES all.cc indice.cc reordering.cc maxpool.cc nms.cc spconv_ops.cc pool_ops.cc) set(ALL_FILES all.cc indice.cc reordering.cc maxpool.cc nms.cc spconv_ops.cc pool_ops.cc point2voxel_ops.cc)
if (SPCONV_BuildCUDA) if (SPCONV_BuildCUDA)
set(ALL_FILES ${ALL_FILES} indice.cu reordering.cu maxpool.cu pillar_scatter.cu cublas_gemm.cc fused_conv.cu) set(ALL_FILES ${ALL_FILES} indice.cu reordering.cu maxpool.cu pillar_scatter.cu cublas_gemm.cc point2voxel.cu fused_conv.cu)
endif() endif()
add_library(spconv SHARED ${ALL_FILES}) add_library(spconv SHARED ${ALL_FILES})
......
...@@ -15,12 +15,14 @@ ...@@ -15,12 +15,14 @@
#include <spconv/fused_spconv_ops.h> #include <spconv/fused_spconv_ops.h>
#include <spconv/nms_ops.h> #include <spconv/nms_ops.h>
#include <spconv/pillar_scatter_ops.h> #include <spconv/pillar_scatter_ops.h>
#include <spconv/point2voxel_ops.h>
#include <spconv/pool_ops.h> #include <spconv/pool_ops.h>
#include <spconv/spconv_ops.h> #include <spconv/spconv_ops.h>
#include <torch/script.h> #include <torch/script.h>
static auto registry = static auto registry =
torch::RegisterOperators() torch::RegisterOperators()
.op("spconv::points_to_voxel", &spconv::pointsToVoxel)
.op("spconv::get_indice_pairs", &spconv::getIndicePairs) .op("spconv::get_indice_pairs", &spconv::getIndicePairs)
.op("spconv::indice_conv", &spconv::indiceConv) .op("spconv::indice_conv", &spconv::indiceConv)
.op("spconv::indice_conv_backward", &spconv::indiceConvBackward) .op("spconv::indice_conv_backward", &spconv::indiceConvBackward)
......
#include <ATen/ATen.h>
#include <spconv/point2voxel.cu.h>
//#include <spconv/point2voxel.h>
#include <tensorview/cuda_utils.h>
#include <tensorview/mp_helper.h>
#include <tensorview/tensor.h>
#include <tensorview/tensorview.h>
#include <tensorview/torch_utils.h>
namespace spconv {
/**
 * Launch scatterPointToGridKernel on the current CUDA stream.
 *
 * The integer index type is taken from the dtype of `pointIndex`, and `ndim`
 * is dispatched over 2, 3 and 4 so the kernel is instantiated with a
 * compile-time dimension count. `points` and `grids` are viewed as float.
 *
 * An empty point cloud returns immediately: there is nothing to scatter, and
 * a launch sized for zero points is not a valid execution configuration.
 *
 * @param points            point features, one row per point.
 * @param indexes           per-point integer grid coordinates.
 * @param grids             per-cell feature sums (accumulated into).
 * @param numPointsPerGrid  per-cell point counts (accumulated into).
 * @param pointIndex        receives each point's flat cell index.
 * @param gridShape         number of cells along each axis.
 * @param ndim              number of spatial axes.
 */
void scatter_point_to_grid_cuda(
    torch::Tensor points,
    torch::Tensor indexes,
    torch::Tensor grids,
    torch::Tensor numPointsPerGrid,
    torch::Tensor pointIndex,
    std::vector<int64_t> gridShape,
    const int ndim) {
  auto num_points = points.size(0);
  if (num_points <= 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  tv::dispatch_torch<int32_t>(pointIndex.scalar_type(), [&](auto IndexValue) {
    using Index = decltype(IndexValue);
    tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) {
      constexpr int NDim = decltype(I)::value;
      tv::SimpleVector<Index, NDim> gs(gridShape.begin(), gridShape.end());
      scatterPointToGridKernel<Index, NDim>
          <<<tv::cuda::getBlocks(num_points), tv::cuda::CUDA_NUM_THREADS,
             0, stream>>>(tv::torch2tv<float>(points),
                          tv::torch2tv<Index>(indexes),
                          tv::torch2tv<float>(grids),
                          tv::torch2tv<Index>(numPointsPerGrid),
                          tv::torch2tv<Index>(pointIndex),
                          gs);
      TV_CHECK_CUDA_ERR_V2("scatterPointToGridKernel failed");
#ifdef TV_LOG_KERNEL_INFO
      cudaFuncAttributes attr;
      checkCudaErrors(cudaFuncGetAttributes(
          &attr, scatterPointToGridKernel<Index, NDim>));
      tv::ssprint("scatterPointToGridKernel<", tv::type_s<Index>, NDim,
                  ">", attr.numRegs);
#endif
    });
  });
}
/**
 * Compact the occupied grid cells into voxels and restore the work buffers.
 *
 * Enqueues, in order, on the current CUDA stream:
 *   1. resetPointIndexKernel  - refill `pointIndex` with the grid volume,
 *   2. gatherPointFromGridKernel - write `voxels` / `coors` for each cell in
 *      `pointIndexUnique`,
 *   3. resetGridKernel - zero those cells in `grids` / `numPointsPerGrid`.
 *
 * The final entry of `pointIndexUnique` is treated as a sentinel, so the
 * number of real cells is `pointIndexUnique.size(0) - 1`. The gather kernel
 * is given a view of `voxels` limited to that many rows (clamped to the
 * buffer capacity), so it never iterates past the unique list. Launches are
 * skipped when there is no work, since a zero-sized launch is invalid.
 *
 * @param grids             per-cell feature sums.
 * @param numPointsPerGrid  per-cell point counts.
 * @param pointIndex        per-point flat cell index buffer (refilled).
 * @param pointIndexUnique  unique values of `pointIndex`.
 * @param voxels            output voxel features.
 * @param coors             output voxel coordinates.
 * @param gridShape         number of cells along each axis.
 * @param ndim              number of spatial axes.
 */
void gather_point_from_grid_cuda(
    torch::Tensor grids, torch::Tensor numPointsPerGrid,
    torch::Tensor pointIndex,
    torch::Tensor pointIndexUnique,
    torch::Tensor voxels, torch::Tensor coors,
    std::vector<int64_t> gridShape,
    const int ndim) {
  auto stream = at::cuda::getCurrentCUDAStream();
  auto num_max_points = pointIndex.size(0) - 1;
  auto grid_volume = grids.size(0);
  // Cells listed in pointIndexUnique, excluding the trailing sentinel.
  auto num_unique = pointIndexUnique.size(0) - 1;
  // Rows that can actually be written to the output buffer.
  auto num_voxel = num_unique < voxels.size(0) ? num_unique : voxels.size(0);
  tv::dispatch_torch<int32_t>(pointIndexUnique.scalar_type(), [&](auto IndexValue) {
    using Index = decltype(IndexValue);
    tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) {
      constexpr int NDim = decltype(I)::value;
      tv::SimpleVector<Index, NDim> gs(gridShape.begin(), gridShape.end());
      if (num_max_points > 0) {
        resetPointIndexKernel<Index>
            <<<tv::cuda::getBlocks(num_max_points), tv::cuda::CUDA_NUM_THREADS,
               0, stream>>>(tv::torch2tv<Index>(pointIndex),
                            static_cast<Index>(grid_volume));
        TV_CHECK_CUDA_ERR_V2("resetPointIndexKernel failed");
      }
#ifdef TV_LOG_KERNEL_INFO
      cudaFuncAttributes attr0;
      // resetPointIndexKernel has a single template parameter (Index).
      checkCudaErrors(cudaFuncGetAttributes(
          &attr0, resetPointIndexKernel<Index>));
      tv::ssprint("resetPointIndexKernel<", tv::type_s<Index>, NDim, ">",
                  attr0.numRegs);
#endif
      if (num_voxel > 0) {
        // Restrict the output view to the rows that have a source cell.
        auto voxels_used = voxels.narrow(0, 0, num_voxel);
        gatherPointFromGridKernel<Index, NDim>
            <<<tv::cuda::getBlocks(num_voxel), tv::cuda::CUDA_NUM_THREADS,
               0, stream>>>(tv::torch2tv<float>(grids),
                            tv::torch2tv<Index>(numPointsPerGrid),
                            tv::torch2tv<Index>(pointIndexUnique),
                            tv::torch2tv<float>(voxels_used),
                            tv::torch2tv<Index>(coors),
                            gs);
        TV_CHECK_CUDA_ERR_V2("gatherPointFromGridKernel failed");
      }
#ifdef TV_LOG_KERNEL_INFO
      cudaFuncAttributes attr1;
      checkCudaErrors(cudaFuncGetAttributes(
          &attr1, gatherPointFromGridKernel<Index, NDim>));
      tv::ssprint("gatherPointFromGridKernel<", tv::type_s<Index>, NDim, ">",
                  attr1.numRegs);
#endif
      if (num_unique > 0) {
        // Every listed cell must be cleared, even if it did not fit in voxels.
        resetGridKernel<Index>
            <<<tv::cuda::getBlocks(num_unique), tv::cuda::CUDA_NUM_THREADS,
               0, stream>>>(tv::torch2tv<float>(grids),
                            tv::torch2tv<Index>(numPointsPerGrid),
                            tv::torch2tv<Index>(pointIndexUnique));
        TV_CHECK_CUDA_ERR_V2("resetGridKernel failed");
      }
#ifdef TV_LOG_KERNEL_INFO
      cudaFuncAttributes attr2;
      // resetGridKernel has a single template parameter (Index).
      checkCudaErrors(cudaFuncGetAttributes(
          &attr2, resetGridKernel<Index>));
      tv::ssprint("resetGridKernel<", tv::type_s<Index>, NDim, ">",
                  attr2.numRegs);
#endif
    });
  });
}
} // namespace spconv
#include <spconv/point2voxel_ops.h>
//#include <spconv/point2voxel.cu.h>
namespace spconv {
/**
 * Voxelize a point cloud using caller-provided work buffers.
 *
 * Steps: scatter the points into the dense grid, take the unique values of
 * `pointIndex`, then gather one output row per unique cell and restore the
 * work buffers. Only CUDA tensors are supported.
 *
 * Before any device write the arguments are validated:
 *   - `points` must not have more rows than `pointIndex.size(0) - 1`
 *     (the last slot of `pointIndex` is reserved for the sentinel);
 *   - `gridShape` must have exactly `ndim` entries.
 *
 * @return number of unique values in `pointIndex` minus one (the sentinel).
 * @throws std::invalid_argument-style error (via TV_THROW_INVALID_ARG) for
 *         CPU tensors, unknown devices, or invalid arguments.
 */
int64_t
pointsToVoxel(torch::Tensor points,
              torch::Tensor indexes,
              torch::Tensor pointIndex,
              torch::Tensor grids,
              torch::Tensor numPointsPerGrid,
              torch::Tensor voxels,
              torch::Tensor coors,
              std::vector<int64_t> gridShape,
              const int64_t ndim) {
  if (points.device().type() == torch::kCPU) {
    TV_THROW_INVALID_ARG("not support cpu currently");
  }
#ifdef TV_CUDA
  else if (points.device().type() == torch::kCUDA) {
    if (points.size(0) > pointIndex.size(0) - 1) {
      TV_THROW_INVALID_ARG("number of points exceeds the capacity of pointIndex");
    }
    if (static_cast<int64_t>(gridShape.size()) != ndim) {
      TV_THROW_INVALID_ARG("gridShape must have exactly ndim entries");
    }
    scatter_point_to_grid_cuda(points, indexes, grids,
                               numPointsPerGrid, pointIndex, gridShape, ndim);
  }
#endif
  else {
    TV_THROW_INVALID_ARG("unknown device type");
  }
  // Reaching this point implies a CUDA tensor in a CUDA-enabled build; every
  // other case has already thrown above, so no second device check is needed.
  auto pointIndexUnique = std::get<0>(torch::_unique(pointIndex));
  const int64_t num_voxel = pointIndexUnique.size(0) - 1;
#ifdef TV_CUDA
  gather_point_from_grid_cuda(grids, numPointsPerGrid,
                              pointIndex, pointIndexUnique, voxels, coors,
                              gridShape, ndim);
#endif
  return num_voxel;
}
} // namespace spconv
import time
from pathlib import Path
import numpy as np
import torch
from torch import nn
import spconv
from spconv.utils import VoxelGeneratorV2, VoxelGeneratorV3
def waymo_data_gpu(batch_size=1):
    """Benchmark GPU voxelization on the bundled point cloud.

    Loads ``data/benchmark-pc.npz`` next to this file, runs one warm-up call
    and 200 timed calls of ``VoxelGeneratorV3.generate``, and prints the mean
    time of the last 100 calls.

    Returns:
        ``(voxels, coors, grid_size)`` where ``coors`` has a zero batch-id
        column prepended. ``batch_size`` is accepted but not used.
    """
    print('gpu with total points available per voxel')
    archive = np.load(Path(__file__).parent / "data" / "benchmark-pc.npz")
    points = torch.from_numpy(archive['pc']).cuda().float()
    voxel_size = torch.Tensor([0.1, 0.1, 0.1]).to(device=points.device, dtype=points.dtype)
    coors_range = torch.Tensor([-80, -80, -2, 80, 80, 6]).to(device=points.device, dtype=points.dtype)
    gen = VoxelGeneratorV3(voxel_size, coors_range,
                           max_points=200000,
                           num_features=points.shape[1],
                           dtype=points.dtype,
                           device=points.device)
    # Warm-up call, not timed.
    voxels, coors = gen.generate(points)
    elapsed = []
    with torch.no_grad():
        for _ in range(200):
            torch.cuda.synchronize()
            started = time.time()
            voxels, coors = gen.generate(points)
            torch.cuda.synchronize()
            elapsed.append(time.time() - started)
    print("voxelization time", np.mean(elapsed[100:]))
    # Prepend a batch index of 0 to every coordinate row.
    batch_id = torch.zeros([coors.shape[0], 1], dtype=coors.dtype, device=coors.device)
    coors = torch.cat([batch_id, coors], dim=1)
    return voxels, coors, gen.grid_size
def waymo_data_cpu(max_points_per_voxel=1, batch_size=1):
    """Benchmark CPU voxelization (``VoxelGeneratorV2``) on the bundled cloud.

    Runs one warm-up call and 200 timed calls of ``generate`` and prints the
    mean time of the last 100 calls.

    Returns:
        ``(voxels, coors, grid_size)`` where ``voxels`` is reshaped to
        ``(-1, 3)`` and ``coors`` has a zero batch-id column prepended.
        ``batch_size`` is accepted but not used.
    """
    print('cpu with %d max points per voxel' % max_points_per_voxel)
    gen = VoxelGeneratorV2([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6],
                           max_points_per_voxel, 150000)
    archive = np.load(Path(__file__).parent / "data" / "benchmark-pc.npz")
    pc = archive["pc"]
    # Warm-up call, not timed.
    result = gen.generate(pc)
    elapsed = []
    with torch.no_grad():
        for _ in range(200):
            torch.cuda.synchronize()
            started = time.time()
            result = gen.generate(pc)
            torch.cuda.synchronize()
            elapsed.append(time.time() - started)
    print("voxelization time", np.mean(elapsed[100:]))
    voxels = result["voxels"].reshape(-1, 3)
    coors = result["coordinates"]
    # Prepend a batch index of 0 to every coordinate row.
    batch_column = np.full([coors.shape[0], 1], 0, coors.dtype)
    coors = np.concatenate([batch_column, coors], axis=1)
    return voxels, coors, gen.grid_size
def get_index(coor, grid_size):
    """Fold a coordinate row into one integer key.

    Starts from ``coor[0]`` and, for each following coordinate paired with the
    corresponding entry of ``grid_size``, computes ``key * size + coordinate``.
    Pairing stops at the shorter of ``coor[1:]`` and ``grid_size``.
    """
    key = coor[0]
    pairs = min(len(coor) - 1, len(grid_size))
    for axis in range(pairs):
        key = key * grid_size[axis] + coor[axis + 1]
    return key
def main():
    """Compare GPU voxelization against the CPU reference.

    Runs the GPU benchmark and the CPU benchmark (1, 10 and 40 points per
    voxel), then checks that both produce the same grid size, the same number
    of voxels, the same set of occupied cells, and per-cell xyz values that
    differ by less than 0.1.
    """
    voxels_gpu, coors_gpu, grid_size_gpu = waymo_data_gpu()
    voxels_cpu, coors_cpu, grid_size_cpu = waymo_data_cpu(1)
    for points_per_voxel in (10, 40):
        waymo_data_cpu(points_per_voxel)
    print('...')

    # Reverse the grid sizes before they are paired with the coordinate columns.
    grid_size_gpu = grid_size_gpu[::-1]
    grid_size_cpu = grid_size_cpu[::-1]
    assert len(grid_size_gpu) == len(grid_size_cpu), "mismatch grid size"
    for axis in range(3):
        assert grid_size_gpu[axis] == grid_size_cpu[axis], "mismatch grid size"
    assert coors_gpu.shape[0] == coors_cpu.shape[0], "mismatch coors shape"

    # Map each GPU voxel's key to its first three feature values.
    index2voxel = {
        get_index(coor, grid_size_gpu).item(): voxel[:3].cpu()
        for coor, voxel in zip(coors_gpu, voxels_gpu)
    }

    for coor, voxel in zip(coors_cpu, voxels_cpu):
        index = get_index(coor, grid_size_cpu).item()
        assert index in index2voxel, "mismatch index: " + str(index)
        gpu_voxel = index2voxel.pop(index)
        assert (gpu_voxel - voxel[:3]).abs().max() < 0.1, \
            "voxel diff should be smaller than voxel_size 0.1"
    print('Perfect GPU Voxelization!!!')


if __name__ == "__main__":
    main()
import time
from pathlib import Path
import numpy as np
import torch
from torch import nn
import spconv
from spconv.utils import VoxelGeneratorV3
def waymo_data(batch_size=1):
    """Load the bundled point cloud and voxelize it on the GPU.

    Returns:
        ``(voxels, coors, grid_size)`` where ``coors`` has a zero batch-id
        column prepended. ``batch_size`` is accepted but not used.
    """
    data = np.load(Path(__file__).parent / "data" / "benchmark-pc.npz")
    points = torch.from_numpy(data['pc']).cuda().float()
    voxel_size = torch.Tensor([0.1, 0.1, 0.1]).to(points.dtype).to(points.device)
    coors_range = torch.Tensor([-80, -80, -2, 80, 80, 6]).to(points.dtype).to(points.device)
    # VoxelGeneratorV3 requires max_points, num_features, dtype and device.
    # Sizing max_points from the cloud itself guarantees generate() accepts it.
    gen = VoxelGeneratorV3(voxel_size, coors_range,
                           max_points=points.shape[0],
                           num_features=points.shape[1],
                           dtype=points.dtype,
                           device=points.device)
    voxels, coors = gen.generate(points)
    N = coors.shape[0]
    batch_id = torch.zeros([N, 1], dtype=coors.dtype, device=coors.device)
    coors = torch.cat([batch_id, coors], dim=1)
    return voxels, coors, gen.grid_size
class Net(nn.Module):
    """Sparse 3-D benchmark network.

    Seven stages of two ``SubMConv3d`` layers each (kernel size 3, no bias,
    sharing one ``indice_key`` per stage), with a ``SparseMaxPool3d(2, 2)``
    between consecutive stages. Channel widths go
    3 -> 64 -> 96 -> 128 -> 160 -> 192 -> 224 -> 256.
    """

    def __init__(self, shape, algo, device):
        super().__init__()
        self.device = device
        # (in_channels, out_channels) of each stage; stage i uses key "c<i>".
        stage_channels = [(3, 64), (64, 96), (96, 128), (128, 160),
                          (160, 192), (192, 224), (224, 256)]
        layers = []
        for stage, (c_in, c_out) in enumerate(stage_channels):
            if stage > 0:
                # Downsample between stages, never after the last one.
                layers.append(spconv.SparseMaxPool3d(2, 2))
            key = "c%d" % stage
            layers.append(spconv.SubMConv3d(c_in, c_out, 3, bias=False,
                                            indice_key=key, algo=algo))
            layers.append(spconv.SubMConv3d(c_out, c_out, 3, bias=False,
                                            indice_key=key, algo=algo))
        self.net = spconv.SparseSequential(*layers)
        max_batch_size = 1
        # grid (dense map) is used for indice generation. use pre-allocated grid can run faster.
        self.grid = torch.full([max_batch_size, *shape], -1,
                               dtype=torch.int32, device=self.device)
        self.shape = shape

    def forward(self, features, coors, batch_size):
        """Wrap the inputs in a ``SparseConvTensor`` and run the network."""
        sparse_input = spconv.SparseConvTensor(features, coors, self.shape,
                                               batch_size, self.grid)
        return self.net(sparse_input)
def main():
    """Time the sparse network on the voxelized benchmark cloud.

    Runs one untimed forward pass, then 20 timed passes under ``no_grad`` and
    prints the mean time of the last 10.
    """
    features, coordinates, spatial_shape = waymo_data()
    device = features.device
    net = Net(spatial_shape[::-1], spconv.ConvAlgo.Native, device)
    net = net.cuda(device=device).eval().float()
    print(coordinates.shape)
    out = net(features, coordinates, 1)
    print(out.spatial_shape)
    elapsed = []
    with torch.no_grad():
        for _ in range(20):
            torch.cuda.synchronize()
            started = time.time()
            out = net(features, coordinates, 1)
            torch.cuda.synchronize()
            elapsed.append(time.time() - started)
    print("spconv time", np.mean(elapsed[10:]))


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment