"example/tensorrt/README.md" did not exist on "899008faa3c629bd43ff6c59ec718fb28728dfb5"
Commit bf473de0 authored by Yan Yan's avatar Yan Yan
Browse files

1. change indicePairs layout to [2, KV, num]

2. working on batch indice conv
parent 6c767a51
...@@ -81,8 +81,8 @@ fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor filters, ...@@ -81,8 +81,8 @@ fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor filters,
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
auto outputBufferBlob = torch::from_blob(outputBuffer.data_ptr(), auto outputBufferBlob = torch::from_blob(outputBuffer.data_ptr(),
{nHot, numOutPlanes}, options); {nHot, numOutPlanes}, options);
auto inputBufferBlob = torch::from_blob(inputBuffer.data_ptr(), auto inputBufferBlob =
{nHot, numInPlanes}, options); torch::from_blob(inputBuffer.data_ptr(), {nHot, numInPlanes}, options);
if (device == torch::kCPU) { if (device == torch::kCPU) {
sparse_gather_cpu(inputBuffer, features, indicePairs[i][inverse], nHot); sparse_gather_cpu(inputBuffer, features, indicePairs[i][inverse], nHot);
...@@ -101,11 +101,13 @@ fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor filters, ...@@ -101,11 +101,13 @@ fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor filters,
// totalGEMMTime += timer.report() / 1000.0; // totalGEMMTime += timer.report() / 1000.0;
if (device == torch::kCPU) { if (device == torch::kCPU) {
sparse_scatter_add_cpu(outputBuffer, output, indicePairs[i][!inverse], nHot); sparse_scatter_add_cpu(outputBuffer, output, indicePairs[i][!inverse],
nHot);
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
sparse_scatter_add_cuda(outputBuffer, output, indicePairs[i][!inverse], nHot); sparse_scatter_add_cuda(outputBuffer, output, indicePairs[i][!inverse],
nHot);
} }
#endif #endif
else { else {
......
...@@ -55,10 +55,10 @@ __global__ void prepareIndicePairsKernel( ...@@ -55,10 +55,10 @@ __global__ void prepareIndicePairsKernel(
pointPtr = validPoints + i * (NDim + 1); pointPtr = validPoints + i * (NDim + 1);
auto offset = pointPtr[NDim]; auto offset = pointPtr[NDim];
auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1)); auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(offset, 0, oldNum) = ix; indicePairs(0, offset, oldNum) = ix;
index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) + index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) +
spatialVolume * indicesIn(ix, 0); spatialVolume * indicesIn(ix, 0);
indicePairs(offset, 1, oldNum) = index; indicePairs(1, offset, oldNum) = index;
indicePairUnique[offset * indicePairsDim2 + oldNum] = index; indicePairUnique[offset * indicePairsDim2 + oldNum] = index;
} }
} }
...@@ -98,10 +98,10 @@ __global__ void prepareDeConvIndicePairsKernel( ...@@ -98,10 +98,10 @@ __global__ void prepareDeConvIndicePairsKernel(
pointPtr = validPoints + i * (NDim + 1); pointPtr = validPoints + i * (NDim + 1);
auto offset = pointPtr[NDim]; auto offset = pointPtr[NDim];
auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1)); auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(offset, 0, oldNum) = ix; indicePairs(0, offset, oldNum) = ix;
index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) + index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) +
spatialVolume * indicesIn(ix, 0); spatialVolume * indicesIn(ix, 0);
indicePairs(offset, 1, oldNum) = index; indicePairs(1, offset, oldNum) = index;
indicePairUnique[offset * indicePairsDim2 + oldNum] = index; indicePairUnique[offset * indicePairsDim2 + oldNum] = index;
} }
} }
...@@ -152,15 +152,16 @@ assignIndicePairsHashKernel(tv::TensorView<Index> indicesOut, int numActIn, ...@@ -152,15 +152,16 @@ assignIndicePairsHashKernel(tv::TensorView<Index> indicesOut, int numActIn,
uint2 stash_constants, unsigned stash_count) { uint2 stash_constants, unsigned stash_count) {
Index index; Index index;
int kernelVolume = indicePairs.dim(0); int kernelVolume = indicePairs.dim(1);
auto indicePairsOut = indicePairs.subview(1);
for (int ix : tv::KernelLoopX<int>(numActIn)) { for (int ix : tv::KernelLoopX<int>(numActIn)) {
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
index = indicePairs(i, 1, ix); index = indicePairsOut(i, ix);
if (index > -1) { if (index > -1) {
auto val = cuhash::retrieve((unsigned)(index), table_size, table, auto val = cuhash::retrieve((unsigned)(index), table_size, table,
constants, stash_constants, stash_count); constants, stash_constants, stash_count);
assert(val != cuhash::kNotFound); assert(val != cuhash::kNotFound);
indicePairs(i, 1, ix) = (unsigned)val; indicePairsOut(i, ix) = (unsigned)val;
} }
} }
} }
...@@ -175,12 +176,14 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut, ...@@ -175,12 +176,14 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut,
const tv::SimpleVector<Index, NDim> outSpatialShape) { const tv::SimpleVector<Index, NDim> outSpatialShape) {
Index index; Index index;
int kernelVolume = indicePairs.dim(0); int kernelVolume = indicePairs.dim(1);
auto indicePairsOut = indicePairs.subview(1);
for (int ix : tv::KernelLoopX<int>(numActIn)) { for (int ix : tv::KernelLoopX<int>(numActIn)) {
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
index = indicePairs(i, 1, ix); index = indicePairsOut(i, ix);
if (index > -1) { if (index > -1) {
indicePairs(i, 1, ix) = gridsOut[index]; indicePairsOut(i, ix) = gridsOut[index];
} }
} }
} }
...@@ -259,8 +262,8 @@ __global__ void getSubMIndicePairsKernel( ...@@ -259,8 +262,8 @@ __global__ void getSubMIndicePairsKernel(
spatialVolume * indicesIn(ix, 0); spatialVolume * indicesIn(ix, 0);
if (gridsOut[index] > -1) { if (gridsOut[index] > -1) {
auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1)); auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(offset, 1, oldNum) = gridsOut[index]; indicePairs(1, offset, oldNum) = gridsOut[index];
indicePairs(offset, 0, oldNum) = ix; indicePairs(0, offset, oldNum) = ix;
} }
} }
} }
...@@ -302,8 +305,8 @@ __global__ void getSubMIndicePairsHashKernel( ...@@ -302,8 +305,8 @@ __global__ void getSubMIndicePairsHashKernel(
constants, stash_constants, stash_count); constants, stash_constants, stash_count);
if (val != cuhash::kNotFound) { if (val != cuhash::kNotFound) {
auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1)); auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(offset, 1, oldNum) = val; indicePairs(1, offset, oldNum) = val;
indicePairs(offset, 0, oldNum) = ix; indicePairs(0, offset, oldNum) = ix;
} }
} }
} }
......
...@@ -14,26 +14,31 @@ ...@@ -14,26 +14,31 @@
#ifndef SPARSE_MAXPOOL_FUNCTOR_H_ #ifndef SPARSE_MAXPOOL_FUNCTOR_H_
#define SPARSE_MAXPOOL_FUNCTOR_H_ #define SPARSE_MAXPOOL_FUNCTOR_H_
#include <tensorview/tensorview.h> #include <tensorview/mp_helper.h>
#include <tensorview/tensor.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h>
namespace spconv { namespace spconv {
namespace functor {
template <typename Device, typename T, typename Index> void maxpool_bwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
struct SparseMaxPoolForwardFunctor { torch::Tensor dout, torch::Tensor din,
void operator()(const Device &d, tv::TensorView<T> outFeatures, torch::Tensor indicesIn, torch::Tensor indicesOut,
tv::TensorView<const T> inFeatures, int size);
tv::TensorView<const Index> indices, int size);
}; void maxpool_fwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
torch::Tensor indicesIn, torch::Tensor indicesOut,
template <typename Device, typename T, typename Index> int size);
struct SparseMaxPoolBackwardFunctor {
void operator()(const Device &d, tv::TensorView<const T> outFeatures, void maxpool_bwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
tv::TensorView<const T> inFeatures, torch::Tensor dout, torch::Tensor din,
tv::TensorView<const T> dout, tv::TensorView<T> din, torch::Tensor indicesIn, torch::Tensor indicesOut,
tv::TensorView<const Index> indices, int size); int size);
};
void maxpool_fwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
} // namespace functor torch::Tensor indicesIn, torch::Tensor indicesOut,
int size);
} // namespace spconv } // namespace spconv
#endif #endif
\ No newline at end of file
...@@ -21,87 +21,14 @@ ...@@ -21,87 +21,14 @@
#include <utility/timer.h> #include <utility/timer.h>
namespace spconv { namespace spconv {
template <typename T>
torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs, torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numAct) { torch::Tensor indiceNum, int64_t numAct);
auto device = features.device().type();
auto kernelVolume = indicePairs.size(0);
auto numInPlanes = features.size(1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor output = torch::zeros({numAct, numInPlanes}, options);
double totalTime = 0;
for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) {
continue;
}
// auto timer = spconv::CudaContextTimer<>();
if (device == torch::kCPU) {
functor::SparseMaxPoolForwardFunctor<tv::CPU, T, int> forwardFtor;
forwardFtor(tv::CPU(), tv::torch2tv<T>(output),
tv::torch2tv<const T>(features),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
functor::SparseMaxPoolForwardFunctor<tv::GPU, T, int> forwardFtor;
forwardFtor(tv::TorchGPU(), tv::torch2tv<T>(output),
tv::torch2tv<const T>(features),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
TV_CHECK_CUDA_ERR();
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
}
// totalTime += timer.report() / 1000.0;
}
// std::cout << "maxpool forward time " << totalTime << std::endl;
return output;
}
template <typename T> torch::Tensor indiceMaxPoolBackward(torch::Tensor features,
torch::Tensor torch::Tensor outFeatures,
indiceMaxPoolBackward(torch::Tensor features, torch::Tensor outFeatures, torch::Tensor outGrad,
torch::Tensor outGrad, torch::Tensor indicePairs, torch::Tensor indicePairs,
torch::Tensor indiceNum) { torch::Tensor indiceNum);
auto device = features.device().type();
auto numInPlanes = features.size(1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
auto kernelVolume = indicePairs.size(0);
for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) {
continue;
}
if (device == torch::kCPU) {
functor::SparseMaxPoolBackwardFunctor<tv::CPU, T, int> backwardFtor;
backwardFtor(tv::CPU(), tv::torch2tv<const T>(outFeatures),
tv::torch2tv<const T>(features),
tv::torch2tv<const T>(outGrad), tv::torch2tv<T>(inputGrad),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
functor::SparseMaxPoolBackwardFunctor<tv::GPU, T, int> backwardFtor;
backwardFtor(tv::TorchGPU(), tv::torch2tv<const T>(outFeatures),
tv::torch2tv<const T>(features),
tv::torch2tv<const T>(outGrad), tv::torch2tv<T>(inputGrad),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
TV_CHECK_CUDA_ERR();
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
}
}
return inputGrad;
}
} // namespace spconv } // namespace spconv
......
...@@ -96,6 +96,90 @@ __global__ void gatherVecBlockKernel(T *buffer, const T *features, ...@@ -96,6 +96,90 @@ __global__ void gatherVecBlockKernel(T *buffer, const T *features,
} }
} }
template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void batchGatherGenericKernel(T *buffer, const T *features,
const Index *indices, int size,
int numPlanes, int batch_stride,
int feature_batch_stride) {
int ILPStrideX[NumILP];
Index inds[NumILP];
Index batchIdx[NumILP];
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) {
if (ix + ILPStrideX[ilp] < size) {
batchIdx[ilp] = ix / feature_batch_stride;
inds[ilp] = indices[ix + ILPStrideX[ilp]] * numPlanes;
}
}
for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
if (ix + ILPStrideX[ilp] < size)
buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy] =
features[inds[ilp] + iy];
}
}
}
}
template <typename T, typename Index, int NumTLP, int NumILP, typename VecType>
__global__ void batchGatherVecKernel(T *buffer, const T *features,
const Index *indices, int size,
int numPlanes) {
int ILPStrideX[NumILP];
Index inds[NumILP];
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) {
if (ix + ILPStrideX[ilp] < size)
inds[ilp] = indices[ix + ILPStrideX[ilp]] * numPlanes;
}
for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
if (ix + ILPStrideX[ilp] < size)
reinterpret_cast<VecType *>(
buffer)[(ix + ILPStrideX[ilp]) * numPlanes + iy] =
reinterpret_cast<const VecType *>(features)[inds[ilp] + iy];
}
}
}
}
template <typename T, typename Index, int NumTLP, int NumILP,
typename VecType = int4>
__global__ void batchGatherVecBlockKernel(T *buffer, const T *features,
const Index *indices, int size,
int numPlanes) {
int ILPStrideY[NumILP];
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = ilp * gridDim.y * blockDim.y;
features += blockIdx.x * NumTLP;
buffer += blockIdx.x * NumTLP;
for (int iy : tv::KernelLoopY<int, NumILP>(size)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
reinterpret_cast<VecType *>(
buffer)[(iy + ILPStrideY[ilp]) * numPlanes + threadIdx.x] =
reinterpret_cast<const VecType *>(
features)[indices[iy + ILPStrideY[ilp]] * numPlanes +
threadIdx.x];
}
}
}
template <typename T, typename Index, int NumTLP, int NumILP> template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void scatterAddGenericKernel(T *outFeatures, const T *buffer, __global__ void scatterAddGenericKernel(T *outFeatures, const T *buffer,
const Index *indices, int size, const Index *indices, int size,
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <torch/script.h> #include <torch/script.h>
namespace spconv { namespace spconv {
void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
torch::Tensor indices, int size); torch::Tensor indices, int size);
void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures, void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
......
...@@ -61,7 +61,7 @@ getIndicePair(torch::Tensor indices, int64_t batchSize, ...@@ -61,7 +61,7 @@ getIndicePair(torch::Tensor indices, int64_t batchSize,
TV_ASSERT_RT_ERR(batchSize * outputVolume < std::numeric_limits<int>::max(), TV_ASSERT_RT_ERR(batchSize * outputVolume < std::numeric_limits<int>::max(),
msg); msg);
torch::Tensor indicePairs = torch::Tensor indicePairs =
torch::full({kernelVolume, 2, numAct}, -1, torch::full({2, kernelVolume, numAct}, -1,
torch::dtype(torch::kInt32).device(indices.device())); torch::dtype(torch::kInt32).device(indices.device()));
torch::Tensor indiceNum = torch::zeros( torch::Tensor indiceNum = torch::zeros(
{kernelVolume}, torch::dtype(torch::kInt32).device(indices.device())); {kernelVolume}, torch::dtype(torch::kInt32).device(indices.device()));
......
...@@ -27,13 +27,13 @@ ...@@ -27,13 +27,13 @@
namespace py = pybind11; namespace py = pybind11;
namespace tv { namespace tv {
template <typename Tarr> bool is_c_stype(const Tarr &arr) { template <typename Tarr> bool is_c_style(const Tarr &arr) {
return bool(arr.flags() & py::array::c_style); return bool(arr.flags() & py::array::c_style);
} }
template <typename T, int Rank = -1> template <typename T, int Rank = -1>
TensorView<T, Rank> arrayt2tv(py::array_t<T> arr) { TensorView<T, Rank> arrayt2tv(py::array_t<T> arr) {
TV_ASSERT_INVALID_ARG(is_c_stype(arr), "array must be c-contiguous array"); TV_ASSERT_INVALID_ARG(is_c_style(arr), "array must be c-contiguous array");
Shape shape; Shape shape;
for (int i = 0; i < arr.ndim(); ++i) { for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i)); shape.push_back(arr.shape(i));
...@@ -46,7 +46,7 @@ TensorView<T, Rank> arrayt2tv(py::array_t<T> arr) { ...@@ -46,7 +46,7 @@ TensorView<T, Rank> arrayt2tv(py::array_t<T> arr) {
template <typename T, int Rank = -1> template <typename T, int Rank = -1>
TensorView<const T> carrayt2tv(py::array_t<T> arr) { TensorView<const T> carrayt2tv(py::array_t<T> arr) {
TV_ASSERT_INVALID_ARG(is_c_stype(arr), "array must be c-contiguous array"); TV_ASSERT_INVALID_ARG(is_c_style(arr), "array must be c-contiguous array");
Shape shape; Shape shape;
for (int i = 0; i < arr.ndim(); ++i) { for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i)); shape.push_back(arr.shape(i));
...@@ -106,7 +106,7 @@ template <typename Tarr> tv::DType get_array_tv_dtype(const Tarr &arr) { ...@@ -106,7 +106,7 @@ template <typename Tarr> tv::DType get_array_tv_dtype(const Tarr &arr) {
} }
template <typename Tarr> Tensor array2tensor(Tarr &arr) { template <typename Tarr> Tensor array2tensor(Tarr &arr) {
TV_ASSERT_INVALID_ARG(is_c_stype(arr), "array must be c-contiguous array"); TV_ASSERT_INVALID_ARG(is_c_style(arr), "array must be c-contiguous array");
TensorShape shape; TensorShape shape;
for (int i = 0; i < arr.ndim(); ++i) { for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i)); shape.push_back(arr.shape(i));
...@@ -115,7 +115,7 @@ template <typename Tarr> Tensor array2tensor(Tarr &arr) { ...@@ -115,7 +115,7 @@ template <typename Tarr> Tensor array2tensor(Tarr &arr) {
} }
template <typename T> Tensor arrayt2tensor(py::array_t<T> &arr) { template <typename T> Tensor arrayt2tensor(py::array_t<T> &arr) {
TV_ASSERT_INVALID_ARG(is_c_stype(arr), "array must be c-contiguous array"); TV_ASSERT_INVALID_ARG(is_c_style(arr), "array must be c-contiguous array");
TensorShape shape; TensorShape shape;
for (int i = 0; i < arr.ndim(); ++i) { for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i)); shape.push_back(arr.shape(i));
......
...@@ -305,7 +305,7 @@ template <> struct TypeToDtypeTraits<const bool> { ...@@ -305,7 +305,7 @@ template <> struct TypeToDtypeTraits<const bool> {
template <class T> constexpr DType type_v = detail::TypeToDtypeTraits<T>::dtype; template <class T> constexpr DType type_v = detail::TypeToDtypeTraits<T>::dtype;
template <class... Ts, typename F> void dispatch(DType t, F &&f) { template <class... Ts, typename F> bool dispatch_noexcept(DType t, F &&f) {
static_assert(sizeof...(Ts) > 0, "you need to provide at least one type"); static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
bool notFound = true; bool notFound = true;
mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) { mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) {
...@@ -314,7 +314,11 @@ template <class... Ts, typename F> void dispatch(DType t, F &&f) { ...@@ -314,7 +314,11 @@ template <class... Ts, typename F> void dispatch(DType t, F &&f) {
notFound = false; notFound = false;
} }
}); });
if (notFound) { return !notFound;
}
template <class... Ts, typename F> void dispatch(DType t, F &&f) {
if (!dispatch_noexcept<Ts...>(t, std::forward<F>(f))) {
std::stringstream ss; std::stringstream ss;
mp_for_each<mp_list<Ts...>>([=, &ss](auto I) { mp_for_each<mp_list<Ts...>>([=, &ss](auto I) {
ss << detail::TypeToString<decltype(I)>::value << " "; ss << detail::TypeToString<decltype(I)>::value << " ";
...@@ -341,8 +345,7 @@ template <typename T, T... Is, typename F> void dispatch_scalar(T idx, F &&f) { ...@@ -341,8 +345,7 @@ template <typename T, T... Is, typename F> void dispatch_scalar(T idx, F &&f) {
} }
} }
template <int... Is, typename F> void dispatch_int(int idx, F &&f) { template <int... Is, typename F> bool dispatch_int_noexcept(int idx, F &&f) {
// used for kernel parameter selection
static_assert(sizeof...(Is) > 0, static_assert(sizeof...(Is) > 0,
"you need to provide at least one candidate"); "you need to provide at least one candidate");
bool notFound = true; bool notFound = true;
...@@ -352,17 +355,11 @@ template <int... Is, typename F> void dispatch_int(int idx, F &&f) { ...@@ -352,17 +355,11 @@ template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
notFound = false; notFound = false;
} }
}); });
if (notFound) { return !notFound;
std::stringstream ss;
mp_for_each<mp_list_c<int, Is...>>(
[=, &ss](auto I) { ss << decltype(I)::value << " "; });
TV_THROW_RT_ERR("unknown value", idx, ", available:", ss.str());
}
} }
template <int... Is, typename F, class BinaryPredicate> template <int... Is, typename F, class BinaryPredicate>
void dispatch_int(int idx, BinaryPredicate p, F &&f) { bool dispatch_int_noexcept(int idx, BinaryPredicate p, F &&f) {
// BinaryPredicate: BinaryPredicate(idx, candidate)
static_assert(sizeof...(Is) > 0, static_assert(sizeof...(Is) > 0,
"you need to provide at least one candidate"); "you need to provide at least one candidate");
bool notFound = true; bool notFound = true;
...@@ -372,7 +369,22 @@ void dispatch_int(int idx, BinaryPredicate p, F &&f) { ...@@ -372,7 +369,22 @@ void dispatch_int(int idx, BinaryPredicate p, F &&f) {
notFound = false; notFound = false;
} }
}); });
if (notFound) { return !notFound;
}
template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
if (!dispatch_int_noexcept<Is...>(idx, std::forward<F>(f))) {
std::stringstream ss;
mp_for_each<mp_list_c<int, Is...>>(
[=, &ss](auto I) { ss << decltype(I)::value << " "; });
TV_THROW_RT_ERR("unknown value", idx, ", available:", ss.str());
}
}
template <int... Is, typename F, class BinaryPredicate>
void dispatch_int(int idx, BinaryPredicate p, F &&f) {
// BinaryPredicate: BinaryPredicate(idx, candidate)
if (!dispatch_int_noexcept<Is...>(idx, p, std::forward<F>(f))) {
std::stringstream ss; std::stringstream ss;
mp_for_each<mp_list_c<int, Is...>>( mp_for_each<mp_list_c<int, Is...>>(
[=, &ss](auto I) { ss << decltype(I)::value << " "; }); [=, &ss](auto I) { ss << decltype(I)::value << " "; });
...@@ -397,13 +409,18 @@ struct Dispatch<T<Args...>> { ...@@ -397,13 +409,18 @@ struct Dispatch<T<Args...>> {
template <class T> struct DispatchInt; template <class T> struct DispatchInt;
template <template<class...> class Tin, template<class, int> class T, int... Ints> // Args should be std::integral_constant<int, value>
struct DispatchInt<Tin<T<int, Ints>...>> { // you need to use type_container<std::integral_constant<int, value>...>
// as template parameter of DispatchInt.
// tv::mp_list_c is ok.
template <template <class...> class T, class... Args>
struct DispatchInt<T<Args...>> {
template <typename F> inline void operator()(int t, F &&f) { template <typename F> inline void operator()(int t, F &&f) {
return dispatch_int<Ints...>(t, std::forward<F>(f)); return dispatch_int<Args::value...>(t, std::forward<F>(f));
} }
template <typename F, typename BinaryPredicate> inline void operator()(int t, BinaryPredicate p, F &&f) { template <typename F, typename BinaryPredicate>
return dispatch_int<Ints...>(t, p, std::forward<F>(f)); inline void operator()(int t, BinaryPredicate p, F &&f) {
return dispatch_int<Args::value...>(t, p, std::forward<F>(f));
} }
}; };
......
...@@ -157,7 +157,7 @@ class SparseConvolution(SparseModule): ...@@ -157,7 +157,7 @@ class SparseConvolution(SparseModule):
if self.inverse: if self.inverse:
assert datas is not None and self.indice_key is not None assert datas is not None and self.indice_key is not None
_, outids, indice_pairs, indice_pair_num, out_spatial_shape = datas _, outids, indice_pairs, indice_pair_num, out_spatial_shape = datas
assert indice_pairs.shape[0] == np.prod( assert indice_pair_num.shape[0] == np.prod(
self.kernel_size self.kernel_size
), "inverse conv must have same kernel size as its couple conv" ), "inverse conv must have same kernel size as its couple conv"
else: else:
......
...@@ -77,14 +77,15 @@ def get_indice_pairs(indices, ...@@ -77,14 +77,15 @@ def get_indice_pairs(indices,
else: else:
out_shape = get_conv_output_size(spatial_shape, ksize, stride, out_shape = get_conv_output_size(spatial_shape, ksize, stride,
padding, dilation) padding, dilation)
else: else:
out_shape = spatial_shape out_shape = spatial_shape
if grid is None: if grid is None:
res = torch.ops.spconv.get_indice_pairs_v2(indices, batch_size, out_shape, res = torch.ops.spconv.get_indice_pairs_v2(indices, batch_size,
spatial_shape, ksize, stride, padding, out_shape, spatial_shape,
dilation, out_padding, int(subm), ksize, stride, padding,
int(transpose), int(use_hash)) dilation, out_padding,
int(subm), int(transpose),
int(use_hash))
return res return res
else: else:
if ndim == 2: if ndim == 2:
...@@ -106,22 +107,17 @@ def indice_conv(features, ...@@ -106,22 +107,17 @@ def indice_conv(features,
num_activate_out, num_activate_out,
inverse=False, inverse=False,
subm=False): subm=False):
return torch.ops.spconv.indice_conv_v2(features, filters, indice_pairs, return torch.ops.spconv.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out, indice_pair_num, num_activate_out,
int(inverse), int(subm)) int(inverse), int(subm))
def fused_indice_conv(features, filters, bias, indice_pairs, indice_pair_num, def fused_indice_conv(features, filters, bias, indice_pairs, indice_pair_num,
num_activate_out, inverse, subm): num_activate_out, inverse, subm):
if features.dtype == torch.half: return torch.ops.spconv.fused_indice_conv_bn(features, filters, bias,
func = torch.ops.spconv.fused_indice_conv_half indice_pairs, indice_pair_num,
elif filters.dtype == torch.float32: num_activate_out,
func = torch.ops.spconv.fused_indice_conv_fp32 int(inverse), int(subm))
else:
raise NotImplementedError
return func(features, filters, bias, indice_pairs, indice_pair_num,
num_activate_out, int(inverse), int(subm))
def indice_conv_backward(features, def indice_conv_backward(features,
...@@ -137,28 +133,15 @@ def indice_conv_backward(features, ...@@ -137,28 +133,15 @@ def indice_conv_backward(features,
def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out): def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out):
if features.dtype == torch.float32: return torch.ops.spconv.indice_maxpool(features, indice_pairs,
return torch.ops.spconv.indice_maxpool_fp32(features, indice_pairs, indice_pair_num, num_activate_out)
indice_pair_num,
num_activate_out)
elif features.dtype == torch.half:
return torch.ops.spconv.indice_maxpool_half(features, indice_pairs,
indice_pair_num,
num_activate_out)
else:
raise NotImplementedError
def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, def indice_maxpool_backward(features, out_features, out_bp, indice_pairs,
indice_pair_num): indice_pair_num):
if features.dtype == torch.float32: return torch.ops.spconv.indice_maxpool_backward(features, out_features,
return torch.ops.spconv.indice_maxpool_backward_fp32( out_bp, indice_pairs,
features, out_features, out_bp, indice_pairs, indice_pair_num) indice_pair_num)
elif features.dtype == torch.half:
return torch.ops.spconv.indice_maxpool_backward_half(
features, out_features, out_bp, indice_pairs, indice_pair_num)
else:
raise NotImplementedError
def nms(boxes, scores, pre_max_size, post_max_size, thresh, eps): def nms(boxes, scores, pre_max_size, post_max_size, thresh, eps):
......
set(ALL_FILES all.cc indice.cc reordering.cc maxpool.cc nms.cc spconv_ops.cc) set(ALL_FILES all.cc indice.cc reordering.cc maxpool.cc nms.cc spconv_ops.cc pool_ops.cc)
if (SPCONV_BuildCUDA) if (SPCONV_BuildCUDA)
set(ALL_FILES ${ALL_FILES} indice.cu reordering.cu maxpool.cu pillar_scatter.cu) set(ALL_FILES ${ALL_FILES} indice.cu reordering.cu maxpool.cu pillar_scatter.cu)
endif() endif()
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <spconv/fused_spconv_ops.h>
#include <spconv/nms_ops.h> #include <spconv/nms_ops.h>
#include <spconv/pillar_scatter_ops.h> #include <spconv/pillar_scatter_ops.h>
#include <spconv/pool_ops.h> #include <spconv/pool_ops.h>
#include <spconv/spconv_ops.h> #include <spconv/spconv_ops.h>
#include <torch/script.h> #include <torch/script.h>
#include <spconv/fused_spconv_ops.h>
static auto registry = static auto registry =
torch::RegisterOperators() torch::RegisterOperators()
...@@ -31,14 +31,9 @@ static auto registry = ...@@ -31,14 +31,9 @@ static auto registry =
&spconv::getIndicePairPreGrid<3>) &spconv::getIndicePairPreGrid<3>)
.op("spconv::indice_conv", &spconv::indiceConv) .op("spconv::indice_conv", &spconv::indiceConv)
.op("spconv::indice_conv_backward", &spconv::indiceConvBackward) .op("spconv::indice_conv_backward", &spconv::indiceConvBackward)
.op("spconv::fused_indice_conv_bn", .op("spconv::fused_indice_conv_bn", &spconv::fusedIndiceConvBatchNorm)
&spconv::fusedIndiceConvBatchNorm) .op("spconv::indice_maxpool", &spconv::indiceMaxPool)
.op("spconv::indice_maxpool_fp32", &spconv::indiceMaxPool<float>) .op("spconv::indice_maxpool_backward", &spconv::indiceMaxPoolBackward)
.op("spconv::indice_maxpool_backward_fp32",
&spconv::indiceMaxPoolBackward<float>)
.op("spconv::indice_maxpool_half", &spconv::indiceMaxPool<at::Half>)
.op("spconv::indice_maxpool_backward_half",
&spconv::indiceMaxPoolBackward<at::Half>)
.op("spconv::nms", &spconv::nonMaxSuppression<float>) .op("spconv::nms", &spconv::nonMaxSuppression<float>)
.op("spconv::pillar_scatter_float", &spconv::pointPillarScatter<float>) .op("spconv::pillar_scatter_float", &spconv::pointPillarScatter<float>)
.op("spconv::pillar_scatter_half", .op("spconv::pillar_scatter_half",
......
...@@ -72,8 +72,8 @@ Index getIndicePairsConv(tv::TensorView<const Index> indicesIn, ...@@ -72,8 +72,8 @@ Index getIndicePairsConv(tv::TensorView<const Index> indicesIn,
hashval = iter->second; hashval = iter->second;
} }
// indicePairs: [K, 2, L] // indicePairs: [K, 2, L]
indicePairs(offset, 0, indiceNum[offset]) = j; indicePairs(0, offset, indiceNum[offset]) = j;
indicePairs(offset, 1, indiceNum[offset]++) = hashval; indicePairs(1, offset, indiceNum[offset]++) = hashval;
} }
} }
return numAct; return numAct;
...@@ -130,8 +130,8 @@ Index getIndicePairsDeConv(tv::TensorView<const Index> indicesIn, ...@@ -130,8 +130,8 @@ Index getIndicePairsDeConv(tv::TensorView<const Index> indicesIn,
hashval = iter->second; hashval = iter->second;
} }
// indicePairs: [K, 2, L] // indicePairs: [K, 2, L]
indicePairs(offset, 0, indiceNum[offset]) = j; indicePairs(0, offset, indiceNum[offset]) = j;
indicePairs(offset, 1, indiceNum[offset]++) = hashval; indicePairs(1, offset, indiceNum[offset]++) = hashval;
} }
} }
return numAct; return numAct;
...@@ -189,8 +189,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn, ...@@ -189,8 +189,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
if (iter != hash.end()) { if (iter != hash.end()) {
#pragma omp atomic capture #pragma omp atomic capture
oldOffset = indiceNum[offset]++; oldOffset = indiceNum[offset]++;
indicePairs(offset, 0, oldOffset) = j; indicePairs(0, offset, oldOffset) = j;
indicePairs(offset, 1, oldOffset) = iter->second; indicePairs(1, offset, oldOffset) = iter->second;
} }
} }
} }
...@@ -245,8 +245,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn, ...@@ -245,8 +245,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
spatialVolume * indicesIn(j, 0); spatialVolume * indicesIn(j, 0);
auto iter = hash.find(index); auto iter = hash.find(index);
if (iter != hash.end()) { if (iter != hash.end()) {
indicePairs(offset, 0, indiceNum[offset]) = j; indicePairs(0, offset, indiceNum[offset]) = j;
indicePairs(offset, 1, indiceNum[offset]++) = iter->second; indicePairs(1, offset, indiceNum[offset]++) = iter->second;
} }
} }
} }
...@@ -264,7 +264,7 @@ int create_conv_indice_pair_cpu( ...@@ -264,7 +264,7 @@ int create_conv_indice_pair_cpu(
auto ndim = outSpatialShape.size(); auto ndim = outSpatialShape.size();
auto numActIn = indicesIn.size(0); auto numActIn = indicesIn.size(0);
int batchSize = gridsOut.size(0); int batchSize = gridsOut.size(0);
auto kernelVolume = indicePairs.size(0); auto kernelVolume = indiceNum.size(0);
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t, int64_t>(indicesIn.scalar_type(), [&](auto V) { tv::dispatch_torch<int32_t, int64_t>(indicesIn.scalar_type(), [&](auto V) {
...@@ -304,7 +304,7 @@ int create_submconv_indice_pair_cpu( ...@@ -304,7 +304,7 @@ int create_submconv_indice_pair_cpu(
auto ndim = outSpatialShape.size(); auto ndim = outSpatialShape.size();
auto numActIn = indicesIn.size(0); auto numActIn = indicesIn.size(0);
int batchSize = gridsOut.size(0); int batchSize = gridsOut.size(0);
auto kernelVolume = indicePairs.size(0); auto kernelVolume = indiceNum.size(0);
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t, int64_t>(indicesIn.scalar_type(), [&](auto V) { tv::dispatch_torch<int32_t, int64_t>(indicesIn.scalar_type(), [&](auto V) {
......
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
namespace spconv { namespace spconv {
using max_kernel_vol_t = tv::mp_list_c<int, 16, 32, 256, 4096>;
int create_conv_indice_pair_p1_cuda( int create_conv_indice_pair_p1_cuda(
torch::Tensor indicesIn, torch::Tensor indicePairs, torch::Tensor indiceNum, torch::Tensor indicesIn, torch::Tensor indicePairs, torch::Tensor indiceNum,
torch::Tensor indicePairUnique, std::vector<int64_t> kernelSize, torch::Tensor indicePairUnique, std::vector<int64_t> kernelSize,
...@@ -38,7 +40,7 @@ int create_conv_indice_pair_p1_cuda( ...@@ -38,7 +40,7 @@ int create_conv_indice_pair_p1_cuda(
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
auto ndim = kernelSize.size(); auto ndim = kernelSize.size();
auto numActIn = indicesIn.size(0); auto numActIn = indicesIn.size(0);
auto kernelVolume = indicePairs.size(0); auto kernelVolume = indiceNum.size(0);
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) { tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) {
...@@ -52,7 +54,7 @@ int create_conv_indice_pair_p1_cuda( ...@@ -52,7 +54,7 @@ int create_conv_indice_pair_p1_cuda(
tv::SimpleVector<Index, NDim> di(dilation.begin(), dilation.end()); tv::SimpleVector<Index, NDim> di(dilation.begin(), dilation.end());
tv::SimpleVector<Index, NDim> ou(outSpatialShape.begin(), tv::SimpleVector<Index, NDim> ou(outSpatialShape.begin(),
outSpatialShape.end()); outSpatialShape.end());
tv::dispatch_int<16, 32, 256, 4096>( tv::DispatchInt<max_kernel_vol_t>()(
kernelVolume, std::less_equal<int>(), [&](auto I2) { kernelVolume, std::less_equal<int>(), [&](auto I2) {
constexpr int MaxKernelVolume = decltype(I2)::value; constexpr int MaxKernelVolume = decltype(I2)::value;
if (transpose) { if (transpose) {
...@@ -91,7 +93,7 @@ int create_conv_indice_pair_p2_cuda( ...@@ -91,7 +93,7 @@ int create_conv_indice_pair_p2_cuda(
int batchSize = gridsOut.size(0); int batchSize = gridsOut.size(0);
int numAct = indicePairUnique.size(0) - 1; int numAct = indicePairUnique.size(0) - 1;
auto kernelVolume = indicePairs.size(0); auto kernelVolume = indiceNum.size(0);
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) { tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) {
...@@ -113,10 +115,9 @@ int create_conv_indice_pair_p2_cuda( ...@@ -113,10 +115,9 @@ int create_conv_indice_pair_p2_cuda(
<<<tv::cuda::getBlocks(numAct), tv::cuda::CUDA_NUM_THREADS, 0, <<<tv::cuda::getBlocks(numAct), tv::cuda::CUDA_NUM_THREADS, 0,
stream>>>(d_values, numAct); stream>>>(d_values, numAct);
TV_CHECK_CUDA_ERR_V2("arangeKernel failed"); TV_CHECK_CUDA_ERR_V2("arangeKernel failed");
bool res = bool res = table.Build(
table.Build(numAct, numAct,
reinterpret_cast<unsigned *>( reinterpret_cast<unsigned *>(indicePairUnique.data_ptr<Index>()),
tv::torch2tv<Index>(indicePairUnique).data()),
d_values); d_values);
cudaFree(d_values); cudaFree(d_values);
TV_CHECK_CUDA_ERR_V2("cudaFree failed"); TV_CHECK_CUDA_ERR_V2("cudaFree failed");
...@@ -182,7 +183,7 @@ int create_submconv_indice_pair_cuda( ...@@ -182,7 +183,7 @@ int create_submconv_indice_pair_cuda(
auto numActIn = indicesIn.size(0); auto numActIn = indicesIn.size(0);
int batchSize = gridsOut.size(0); int batchSize = gridsOut.size(0);
auto kernelVolume = indicePairs.size(0); auto kernelVolume = indiceNum.size(0);
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) { tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) {
...@@ -212,7 +213,7 @@ int create_submconv_indice_pair_cuda( ...@@ -212,7 +213,7 @@ int create_submconv_indice_pair_cuda(
bool res = bool res =
table.Build(numActIn, reinterpret_cast<unsigned *>(d_keyvalues), table.Build(numActIn, reinterpret_cast<unsigned *>(d_keyvalues),
reinterpret_cast<unsigned *>(d_values)); reinterpret_cast<unsigned *>(d_values));
cudaFree(d_values); cudaFree(d_keyvalues);
TV_CHECK_CUDA_ERR_V2("cudaFree failed"); TV_CHECK_CUDA_ERR_V2("cudaFree failed");
if (!res) { if (!res) {
return -1; // use -1 to tell outside use CPU implementation return -1; // use -1 to tell outside use CPU implementation
...@@ -222,7 +223,7 @@ int create_submconv_indice_pair_cuda( ...@@ -222,7 +223,7 @@ int create_submconv_indice_pair_cuda(
auto constants = table.get_constants_4(); auto constants = table.get_constants_4();
auto stash_constants = table.get_stash_constants(); auto stash_constants = table.get_stash_constants();
auto stash_count = table.get_stash_count(); auto stash_count = table.get_stash_count();
tv::dispatch_int<16, 32, 256, 4096>( tv::DispatchInt<max_kernel_vol_t>()(
kernelVolume, std::less_equal<int>(), [&](auto I2) { kernelVolume, std::less_equal<int>(), [&](auto I2) {
constexpr int MaxKernelVolume = decltype(I2)::value; constexpr int MaxKernelVolume = decltype(I2)::value;
getSubMIndicePairsHashKernel<Index, NDim, MaxKernelVolume> getSubMIndicePairsHashKernel<Index, NDim, MaxKernelVolume>
...@@ -240,7 +241,7 @@ int create_submconv_indice_pair_cuda( ...@@ -240,7 +241,7 @@ int create_submconv_indice_pair_cuda(
stream>>>(tv::torch2tv<Index>(indicesIn), stream>>>(tv::torch2tv<Index>(indicesIn),
tv::torch2tv<IndexGrid>(gridsOut), ou); tv::torch2tv<IndexGrid>(gridsOut), ou);
TV_CHECK_CUDA_ERR_V2("prepareSubMGridKernel failed"); TV_CHECK_CUDA_ERR_V2("prepareSubMGridKernel failed");
tv::dispatch_int<16, 32, 256, 4096>( tv::DispatchInt<max_kernel_vol_t>()(
ndim, std::less_equal<int>(), [&](auto I2) { ndim, std::less_equal<int>(), [&](auto I2) {
constexpr int MaxKernelVolume = decltype(I2)::value; constexpr int MaxKernelVolume = decltype(I2)::value;
getSubMIndicePairsKernel<Index, IndexGrid, NDim, MaxKernelVolume> getSubMIndicePairsKernel<Index, IndexGrid, NDim, MaxKernelVolume>
...@@ -315,7 +316,7 @@ struct CreateConvIndicePairFunctorP2<tv::GPU, Index, IndexGrid, NDim> { ...@@ -315,7 +316,7 @@ struct CreateConvIndicePairFunctorP2<tv::GPU, Index, IndexGrid, NDim> {
const tv::SimpleVector<Index, NDim> outSpatialShape, const tv::SimpleVector<Index, NDim> outSpatialShape,
bool transpose, bool resetGrid, bool useHash) { bool transpose, bool resetGrid, bool useHash) {
Index batchSize = gridsOut.dim(0); Index batchSize = gridsOut.dim(0);
auto kernelVolume = indicePairs.dim(0); auto kernelVolume = indiceNum.dim(0);
auto numActIn = indicesIn.dim(0); auto numActIn = indicesIn.dim(0);
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
......
...@@ -17,66 +17,66 @@ ...@@ -17,66 +17,66 @@
namespace spconv { namespace spconv {
namespace functor { using float_types_t = tv::mp_list<float, double, at::Half>;
template <typename T, typename Index> using int_types_t = tv::mp_list<int32_t, int64_t>;
struct SparseMaxPoolForwardFunctor<tv::CPU, T, Index> {
void operator()(const tv::CPU &d, tv::TensorView<T> outFeatures, void maxpool_fwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
tv::TensorView<const T> inFeatures, torch::Tensor indicesIn, torch::Tensor indicesOut,
tv::TensorView<const Index> indices, int size) { int size) {
int stride = outFeatures.dim(1); if (size <= 0)
auto outFeaturesData = outFeatures.data(); return;
auto inFeaturesData = inFeatures.data(); int stride = inFeatures.size(1);
auto indicesIn = indices.subview(0).data(); auto dtype = inFeatures.scalar_type();
auto indicesOut = indices.subview(1).data(); auto int_dtype = indicesIn.scalar_type();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue);
tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue);
auto outFeaturesData = outFeatures.data_ptr<T>();
auto inFeaturesData = inFeatures.data_ptr<T>();
auto indicesInData = indicesIn.data_ptr<Index>();
auto indicesOutData = indicesOut.data_ptr<Index>();
Index idxi, idxo; Index idxi, idxo;
for (int row = 0; row < size; row++) { for (int row = 0; row < size; row++) {
idxi = indicesIn[row] * stride; idxi = indicesInData[row] * stride;
idxo = indicesOut[row] * stride; idxo = indicesOutData[row] * stride;
for (int plane = 0; plane < stride; ++plane) for (int plane = 0; plane < stride; ++plane)
if (outFeaturesData[idxo + plane] < inFeaturesData[idxi + plane]) if (outFeaturesData[idxo + plane] < inFeaturesData[idxi + plane])
outFeaturesData[idxo + plane] = inFeaturesData[idxi + plane]; outFeaturesData[idxo + plane] = inFeaturesData[idxi + plane];
} }
} });
}; });
}
template <typename T, typename Index> void maxpool_bwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
struct SparseMaxPoolBackwardFunctor<tv::CPU, T, Index> { torch::Tensor dout, torch::Tensor din,
void operator()(const tv::CPU &d, tv::TensorView<const T> outFeatures, torch::Tensor indicesIn, torch::Tensor indicesOut,
tv::TensorView<const T> inFeatures, int size) {
tv::TensorView<const T> dout, tv::TensorView<T> din, if (size <= 0)
tv::TensorView<const Index> indices, int size) { return;
int stride = outFeatures.dim(1); int stride = inFeatures.size(1);
auto outFeaturesData = outFeatures.data(); auto dtype = inFeatures.scalar_type();
auto inFeaturesData = inFeatures.data(); auto int_dtype = indicesIn.scalar_type();
auto doutData = dout.data(); tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
auto dinData = din.data(); using T = decltype(TValue);
auto indicesIn = indices.subview(0).data(); tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
auto indicesOut = indices.subview(1).data(); using Index = decltype(IndexValue);
auto outFeaturesData = outFeatures.data_ptr<T>();
auto inFeaturesData = inFeatures.data_ptr<T>();
auto doutData = dout.data_ptr<T>();
auto dinData = din.data_ptr<T>();
auto indicesInData = indicesIn.data_ptr<Index>();
auto indicesOutData = indicesOut.data_ptr<Index>();
Index idxi, idxo; Index idxi, idxo;
for (int row = 0; row < size; row++) { for (int row = 0; row < size; row++) {
idxi = indicesIn[row] * stride; idxi = indicesInData[row] * stride;
idxo = indicesOut[row] * stride; idxo = indicesOutData[row] * stride;
for (int plane = 0; plane < stride; ++plane) for (int plane = 0; plane < stride; ++plane)
if (outFeaturesData[idxo + plane] == inFeaturesData[idxi + plane]) if (outFeaturesData[idxo + plane] == inFeaturesData[idxi + plane])
dinData[idxi + plane] += doutData[idxo + plane]; dinData[idxi + plane] += doutData[idxo + plane];
} }
} });
}; });
} // namespace functor }
#define DECLARE_CPU_SPECS_T_INDEX(T, Index) \
template struct functor::SparseMaxPoolForwardFunctor<tv::CPU, T, Index>; \
template struct functor::SparseMaxPoolBackwardFunctor<tv::CPU, T, Index>;
#define DECLARE_CPU_SPECS(T) \
DECLARE_CPU_SPECS_T_INDEX(T, int); \
DECLARE_CPU_SPECS_T_INDEX(T, long);
DECLARE_CPU_SPECS(float);
DECLARE_CPU_SPECS(double);
DECLARE_CPU_SPECS(at::Half);
#undef DECLARE_CPU_SPECS
#undef DECLARE_CPU_SPECS_T_INDEX
} // namespace spconv } // namespace spconv
...@@ -306,22 +306,31 @@ maxPoolBwdGenericKernel(const T *outFeatures, const T *inFeatures, ...@@ -306,22 +306,31 @@ maxPoolBwdGenericKernel(const T *outFeatures, const T *inFeatures,
} }
} }
namespace functor { using float_types_t = tv::mp_list<float, double, at::Half>;
template <typename T, typename Index> using int_types_t = tv::mp_list<int32_t, int64_t>;
struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> {
void maxpool_fwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
torch::Tensor indicesIn, torch::Tensor indicesOut,
int size) {
if (size <= 0)
return;
int numPlanes = inFeatures.size(1);
auto dtype = inFeatures.scalar_type();
auto int_dtype = indicesIn.scalar_type();
auto stream = at::cuda::getCurrentCUDAStream();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue);
using vecload_type_t = using vecload_type_t =
std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>; std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>;
using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>; using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>;
void operator()(const tv::GPU &d, tv::TensorView<T> outFeatures,
tv::TensorView<const T> inFeatures, tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
tv::TensorView<const Index> indices, int size) { using Index = decltype(IndexValue);
if (size <= 0)
return;
int numPlanes = inFeatures.dim(1);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &indices, tv::mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &indicesIn,
&notFound](auto NumTLP) { &indicesOut, &notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
int numHotBlock = (size / NumTLP) * NumTLP; int numHotBlock = (size / NumTLP) * NumTLP;
...@@ -332,19 +341,20 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> { ...@@ -332,19 +341,20 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> {
vecload_type_t> vecload_type_t>
<<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP), <<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
d.getStream()>>>(outFeatures.data(), inFeatures.data(), stream>>>(
indices.subview(0).data(), outFeatures.data_ptr<T>(), inFeatures.data_ptr<T>(),
indices.subview(1).data(), numHotBlock, indicesIn.data_ptr<Index>(), indicesOut.data_ptr<Index>(),
numPlanes / vecloadFactor); numHotBlock, numPlanes / vecloadFactor);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
if (size > numHotBlock) { if (size > numHotBlock) {
maxPoolFwdGenericKernel<T, Index, int(NumTLP), NumILP> maxPoolFwdGenericKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP), <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, d.getStream()>>>(outFeatures.data(), inFeatures.data(), 0, stream>>>(outFeatures.data_ptr<T>(),
indices.subview(0).data() + numHotBlock, inFeatures.data_ptr<T>(),
indices.subview(1).data() + numHotBlock, indicesIn.data_ptr<Index>() + numHotBlock,
indicesOut.data_ptr<Index>() + numHotBlock,
size - numHotBlock, numPlanes); size - numHotBlock, numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
...@@ -360,9 +370,9 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> { ...@@ -360,9 +370,9 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> {
if (numHotBlock >= NumTLP) { if (numHotBlock >= NumTLP) {
maxPoolFwdGenericBlockKernel<T, Index, NumTLP, NumILP> maxPoolFwdGenericBlockKernel<T, Index, NumTLP, NumILP>
<<<dim3(size / NumTLP, tv::cuda::DivUp(numPlanes, NumTLP)), <<<dim3(size / NumTLP, tv::cuda::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, stream>>>(
outFeatures.data(), inFeatures.data(), outFeatures.data_ptr<T>(), inFeatures.data_ptr<T>(),
indices.subview(0).data(), indices.subview(1).data(), indicesIn.data_ptr<Index>(), indicesOut.data_ptr<Index>(),
numHotBlock, numPlanes); numHotBlock, numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
...@@ -370,35 +380,42 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> { ...@@ -370,35 +380,42 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> {
if (size > numHotBlock) { if (size > numHotBlock) {
maxPoolFwdGenericKernel<T, Index, NumTLP, NumILP> maxPoolFwdGenericKernel<T, Index, NumTLP, NumILP>
<<<dim3(1, tv::cuda::DivUp(numPlanes, NumTLP)), <<<dim3(1, tv::cuda::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, stream>>>(
outFeatures.data(), inFeatures.data(), outFeatures.data_ptr<T>(), inFeatures.data_ptr<T>(),
indices.subview(0).data() + numHotBlock, indicesIn.data_ptr<Index>() + numHotBlock,
indices.subview(1).data() + numHotBlock, size - numHotBlock, indicesOut.data_ptr<Index>() + numHotBlock,
numPlanes); size - numHotBlock, numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
} }
} });
}; });
}
void maxpool_bwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
torch::Tensor dout, torch::Tensor din,
torch::Tensor indicesIn, torch::Tensor indicesOut,
int size) {
if (size <= 0)
return;
int numPlanes = inFeatures.size(1);
auto dtype = inFeatures.scalar_type();
auto int_dtype = indicesIn.scalar_type();
auto stream = at::cuda::getCurrentCUDAStream();
template <typename T, typename Index> tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { using T = decltype(TValue);
using vecload_type_t = using vecload_type_t =
std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>; std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>;
using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>; using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>;
void operator()(const tv::GPU &d, tv::TensorView<const T> outFeatures, tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
tv::TensorView<const T> inFeatures, using Index = decltype(IndexValue);
tv::TensorView<const T> dout, tv::TensorView<T> din,
tv::TensorView<const Index> indices, int size) {
if (size <= 0)
return;
int numPlanes = inFeatures.dim(1);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &dout, &din, tv::mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &dout,
&indices, &notFound](auto NumTLP) { &din, &indicesIn, &indicesOut,
&notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
int numHotBlock = (size / NumTLP) * NumTLP; int numHotBlock = (size / NumTLP) * NumTLP;
if (notFound) { if (notFound) {
if (numPlanes % NumTLP == 0) { if (numPlanes % NumTLP == 0) {
...@@ -407,10 +424,10 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { ...@@ -407,10 +424,10 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> {
vecload_type_t> vecload_type_t>
<<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP), <<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
d.getStream()>>>(outFeatures.data(), inFeatures.data(), stream>>>(outFeatures.data_ptr<T>(),
dout.data(), din.data(), inFeatures.data_ptr<T>(), dout.data_ptr<T>(),
indices.subview(0).data(), din.data_ptr<T>(), indicesIn.data_ptr<Index>(),
indices.subview(1).data(), numHotBlock, indicesOut.data_ptr<Index>(), numHotBlock,
numPlanes / vecloadFactor); numPlanes / vecloadFactor);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
...@@ -418,10 +435,11 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { ...@@ -418,10 +435,11 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> {
if (size > numHotBlock) { if (size > numHotBlock) {
maxPoolBwdGenericKernel<T, Index, int(NumTLP), NumILP> maxPoolBwdGenericKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP), <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, d.getStream()>>>(outFeatures.data(), inFeatures.data(), 0, stream>>>(outFeatures.data_ptr<T>(),
dout.data(), din.data(), inFeatures.data_ptr<T>(), dout.data_ptr<T>(),
indices.subview(0).data() + numHotBlock, din.data_ptr<T>(),
indices.subview(1).data() + numHotBlock, indicesIn.data_ptr<Index>() + numHotBlock,
indicesOut.data_ptr<Index>() + numHotBlock,
size - numHotBlock, numPlanes); size - numHotBlock, numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
...@@ -437,9 +455,10 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { ...@@ -437,9 +455,10 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> {
if (numHotBlock >= NumTLP) { if (numHotBlock >= NumTLP) {
maxPoolBwdGenericBlockKernel<T, Index, NumTLP, NumILP> maxPoolBwdGenericBlockKernel<T, Index, NumTLP, NumILP>
<<<dim3(size / NumTLP, tv::cuda::DivUp(numPlanes, NumTLP)), <<<dim3(size / NumTLP, tv::cuda::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, stream>>>(
outFeatures.data(), inFeatures.data(), dout.data(), din.data(), outFeatures.data_ptr<T>(), inFeatures.data_ptr<T>(),
indices.subview(0).data(), indices.subview(1).data(), dout.data_ptr<T>(), din.data_ptr<T>(),
indicesIn.data_ptr<Index>(), indicesOut.data_ptr<Index>(),
numHotBlock, numPlanes); numHotBlock, numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
...@@ -447,29 +466,17 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { ...@@ -447,29 +466,17 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> {
if (size > numHotBlock) { if (size > numHotBlock) {
maxPoolBwdGenericKernel<T, Index, NumTLP, NumILP> maxPoolBwdGenericKernel<T, Index, NumTLP, NumILP>
<<<dim3(1, tv::cuda::DivUp(numPlanes, NumTLP)), <<<dim3(1, tv::cuda::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, stream>>>(
outFeatures.data(), inFeatures.data(), dout.data(), din.data(), outFeatures.data_ptr<T>(), inFeatures.data_ptr<T>(),
indices.subview(0).data() + numHotBlock, dout.data_ptr<T>(), din.data_ptr<T>(),
indices.subview(1).data() + numHotBlock, size - numHotBlock, indicesIn.data_ptr<Index>() + numHotBlock,
numPlanes); indicesOut.data_ptr<Index>() + numHotBlock,
size - numHotBlock, numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
} }
} });
}; });
}
} // namespace functor
#define DECLARE_GPU_SPECS_T_INDEX(T, Index) \
template struct functor::SparseMaxPoolForwardFunctor<tv::GPU, T, Index>; \
template struct functor::SparseMaxPoolBackwardFunctor<tv::GPU, T, Index>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPECS_T_INDEX(T, int);
DECLARE_GPU_SPECS(float);
DECLARE_GPU_SPECS(double);
DECLARE_GPU_SPECS(at::Half);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_T_INDEX
} // namespace spconv } // namespace spconv
\ No newline at end of file
#include <spconv/pool_ops.h>
namespace spconv {
torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numAct) {
auto device = features.device().type();
auto kernelVolume = indiceNum.size(0);
auto numInPlanes = features.size(1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor output = torch::zeros({numAct, numInPlanes}, options);
double totalTime = 0;
for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) {
continue;
}
// auto timer = spconv::CudaContextTimer<>();
if (device == torch::kCPU) {
maxpool_fwd_cpu(output, features, indicePairs[0][i], indicePairs[1][i],
nHot);
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
maxpool_fwd_cuda(output, features, indicePairs[0][i], indicePairs[1][i],
nHot);
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
}
// totalTime += timer.report() / 1000.0;
}
// std::cout << "maxpool forward time " << totalTime << std::endl;
return output;
}
torch::Tensor indiceMaxPoolBackward(torch::Tensor features,
torch::Tensor outFeatures,
torch::Tensor outGrad,
torch::Tensor indicePairs,
torch::Tensor indiceNum) {
auto device = features.device().type();
auto numInPlanes = features.size(1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
auto kernelVolume = indiceNum.size(0);
for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) {
continue;
}
if (device == torch::kCPU) {
maxpool_bwd_cpu(outFeatures, features, outGrad, inputGrad,
indicePairs[0][i], indicePairs[1][i], nHot);
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
maxpool_bwd_cuda(outFeatures, features, outGrad, inputGrad,
indicePairs[0][i], indicePairs[1][i], nHot);
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
}
}
return inputGrad;
}
} // namespace spconv
\ No newline at end of file
...@@ -28,27 +28,28 @@ ...@@ -28,27 +28,28 @@
namespace spconv { namespace spconv {
using float_types_t = tv::mp_list<float, double, at::Half>;
using int_types_t = tv::mp_list<int32_t, int64_t>;
void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
torch::Tensor indices, int size) { torch::Tensor indices, int size) {
if (size <= 0) if (size <= 0)
return; return;
int numPlanes = features.size(1); int numPlanes = features.size(1);
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
auto dtype = features.scalar_type();
tv::dispatch_torch<float, double, auto inds_dtype = indices.scalar_type();
at::Half>(features.scalar_type(), [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = decltype(TValue);
using vecload_type_t = using vecload_type_t =
std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>; std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>;
using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>; using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
tv::dispatch_torch<int32_t, int64_t>(
indices.scalar_type(), [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = decltype(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>([=, &buffer, &features, &indices, tv::mp_for_each<kernel_block_t>(
&notFound](auto NumTLP) { [=, &buffer, &features, &indices, &notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
// constexpr int NumILP = NumTLP / (64 / (NumTLP / vecloadFactor)); // constexpr int NumILP = NumTLP / (64 / (NumTLP / vecloadFactor));
int nHotBlock = (size / NumTLP) * NumTLP; int nHotBlock = (size / NumTLP) * NumTLP;
...@@ -101,22 +102,21 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures, ...@@ -101,22 +102,21 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
return; return;
int numPlanes = outFeatures.size(1); int numPlanes = outFeatures.size(1);
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
auto dtype = outFeatures.scalar_type();
auto inds_dtype = indices.scalar_type();
tv::dispatch_torch<float, double, at::Half>( tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
outFeatures.scalar_type(), [&](auto TValue) {
using T = decltype(TValue); using T = decltype(TValue);
using vecload_type_t = using vecload_type_t =
std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>; std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>;
using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>; using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
tv::dispatch_torch<int32_t, int64_t>(
indices.scalar_type(), [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = decltype(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = constexpr int vecloadFactor =
sizeof(vecload_type_t) / sizeof(T); // important for half. sizeof(vecload_type_t) / sizeof(T); // important for half.
tv::mp_for_each<kernel_block_t>( tv::mp_for_each<kernel_block_t>([=, &outFeatures, &buffer, &indices,
[=, &outFeatures, &buffer, &indices, &notFound](auto NumTLP) { &notFound](auto NumTLP) {
// constexpr int NumILP = NumTLP / (64 / (NumTLP / // constexpr int NumILP = NumTLP / (64 / (NumTLP /
// vecloadFactor)); // vecloadFactor));
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
...@@ -124,22 +124,19 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures, ...@@ -124,22 +124,19 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
if (notFound) { if (notFound) {
if (numPlanes % NumTLP == 0) { if (numPlanes % NumTLP == 0) {
if (nHotBlock >= NumTLP) { if (nHotBlock >= NumTLP) {
scatterAddVecBlockKernel<T, Index, int(NumTLP), scatterAddVecBlockKernel<T, Index, int(NumTLP), NumILP,
NumILP, vecload_type_t> vecload_type_t>
<<<dim3(numPlanes / NumTLP, size / NumTLP), <<<dim3(numPlanes / NumTLP, size / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
0, stream>>>(outFeatures.data_ptr<T>(), stream>>>(outFeatures.data_ptr<T>(), buffer.data_ptr<T>(),
buffer.data_ptr<T>(), indices.data_ptr<Index>(), nHotBlock,
indices.data_ptr<Index>(),
nHotBlock,
numPlanes / vecloadFactor); numPlanes / vecloadFactor);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
if (size - nHotBlock > 0) { if (size - nHotBlock > 0) {
scatterAddGenericKernel<T, Index, int(NumTLP), NumILP> scatterAddGenericKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
dim3(NumTLP / NumILP, NumTLP), 0, stream>>>( 0, stream>>>(outFeatures.data_ptr<T>(),
outFeatures.data_ptr<T>(),
buffer.data_ptr<T>() + nHotBlock * numPlanes, buffer.data_ptr<T>() + nHotBlock * numPlanes,
indices.data_ptr<Index>() + nHotBlock, indices.data_ptr<Index>() + nHotBlock,
size - nHotBlock, numPlanes); size - nHotBlock, numPlanes);
......
...@@ -39,7 +39,7 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize, ...@@ -39,7 +39,7 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
TV_ASSERT_RT_ERR(batchSize * outputVolume < std::numeric_limits<int>::max(), TV_ASSERT_RT_ERR(batchSize * outputVolume < std::numeric_limits<int>::max(),
msg); msg);
torch::Tensor indicePairs = torch::Tensor indicePairs =
torch::full({kernelVolume, 2, numAct}, -1, torch::full({2, kernelVolume, numAct}, -1,
torch::dtype(torch::kInt32).device(indices.device())); torch::dtype(torch::kInt32).device(indices.device()));
torch::Tensor indiceNum = torch::zeros( torch::Tensor indiceNum = torch::zeros(
{kernelVolume}, torch::dtype(torch::kInt32).device(indices.device())); {kernelVolume}, torch::dtype(torch::kInt32).device(indices.device()));
...@@ -68,6 +68,18 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize, ...@@ -68,6 +68,18 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
numActOut = create_submconv_indice_pair_cuda( numActOut = create_submconv_indice_pair_cuda(
indices, gridOut, indicePairs, indiceNum, kernelSize, stride, padding, indices, gridOut, indicePairs, indiceNum, kernelSize, stride, padding,
dilation, outSpatialShape, transpose, false, useHash); dilation, outSpatialShape, transpose, false, useHash);
if (numActOut == -1) {
auto device = indices.device();
indicePairs = indicePairs.to({torch::kCPU});
indiceNum = indiceNum.to({torch::kCPU});
indices = indices.to({torch::kCPU});
numActOut = create_submconv_indice_pair_cpu(
indices, gridOut, indicePairs, indiceNum, kernelSize, stride,
padding, dilation, outSpatialShape, transpose, false, useHash);
return {indices.to(device), indicePairs.to(device),
indiceNum.to(device)};
}
} }
#endif #endif
else { else {
...@@ -97,6 +109,20 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize, ...@@ -97,6 +109,20 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
numActOut = create_conv_indice_pair_p2_cuda( numActOut = create_conv_indice_pair_p2_cuda(
indices, outInds, gridOut, indicePairs, indiceNum, indicePairUnique, indices, outInds, gridOut, indicePairs, indiceNum, indicePairUnique,
outSpatialShape, transpose, false, useHash); outSpatialShape, transpose, false, useHash);
if (numActOut == -1) {
auto device = indices.device();
outInds = outInds.to({torch::kCPU});
indicePairs = indicePairs.to({torch::kCPU});
indiceNum = indiceNum.to({torch::kCPU});
indices = indices.to({torch::kCPU});
numActOut = create_conv_indice_pair_cpu(
indices, outInds, gridOut, indicePairs, indiceNum, kernelSize,
stride, padding, dilation, outSpatialShape, transpose, false,
useHash);
return {outInds.to(device).slice(0, 0, numActOut),
indicePairs.to(device), indiceNum.to(device)};
}
} }
} }
#endif #endif
...@@ -114,7 +140,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -114,7 +140,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
bool inverse = _inverse != 0; bool inverse = _inverse != 0;
auto device = features.device().type(); auto device = features.device().type();
auto ndim = filters.dim() - 2; auto ndim = filters.dim() - 2;
auto kernelVolume = indicePairs.size(0); auto kernelVolume = indiceNum.size(0);
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
...@@ -125,21 +151,8 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -125,21 +151,8 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>(); indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
int indicePairMaxSize = *indicePairMaxSizeIter; int indicePairMaxSize = *indicePairMaxSizeIter;
/*if (_subM){
std::vector<int> indicePairNumVec(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset);
auto indicePairVecMaxSizeIter = std::max_element(
indicePairNumVec.begin(), indicePairNumVec.end());
indicePairMaxSize = *indicePairVecMaxSizeIter;
}*/
auto options = auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device()); torch::TensorOptions().dtype(features.dtype()).device(features.device());
// auto indicePairOptions =
// torch::TensorOptions().dtype(torch::kInt64).device(indicePairs.device());
torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options); torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options);
torch::Tensor inputBuffer = torch::Tensor inputBuffer =
torch::zeros({indicePairMaxSize, numInPlanes}, options); torch::zeros({indicePairMaxSize, numInPlanes}, options);
...@@ -159,17 +172,17 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -159,17 +172,17 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
continue; continue;
} }
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
auto outputBufferBlob = torch::from_blob( auto outputBufferBlob = torch::from_blob(outputBuffer.data_ptr(),
outputBuffer.data_ptr(), {nHot, numOutPlanes}, options); {nHot, numOutPlanes}, options);
auto inputBufferBlob = torch::from_blob(inputBuffer.data_ptr(), auto inputBufferBlob =
{nHot, numInPlanes}, options); torch::from_blob(inputBuffer.data_ptr(), {nHot, numInPlanes}, options);
if (device == torch::kCPU) { if (device == torch::kCPU) {
sparse_gather_cpu(inputBuffer, features, indicePairs[i][inverse], nHot); sparse_gather_cpu(inputBuffer, features, indicePairs[inverse][i], nHot);
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
sparse_gather_cuda(inputBuffer, features, indicePairs[i][inverse], nHot); sparse_gather_cuda(inputBuffer, features, indicePairs[inverse][i], nHot);
/* slower than SparseGatherFunctor, may due to int->long conversion /* slower than SparseGatherFunctor, may due to int->long conversion
auto indicePairLong = indicePairs[i][inverse].to(torch::kInt64); auto indicePairLong = indicePairs[i][inverse].to(torch::kInt64);
auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(), auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(),
...@@ -180,32 +193,24 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -180,32 +193,24 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_ASSERT_INVALID_ARG(false, "unknown device type");
} }
// totalGatherTime += timer.report() / 1000.0;
torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]); torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]);
// totalGEMMTime += timer.report() / 1000.0;
if (device == torch::kCPU) { if (device == torch::kCPU) {
sparse_scatter_add_cpu(outputBuffer, output, indicePairs[i][!inverse], nHot); sparse_scatter_add_cpu(outputBuffer, output, indicePairs[!inverse][i],
nHot);
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
sparse_scatter_add_cuda(outputBuffer, output, indicePairs[i][!inverse], nHot); sparse_scatter_add_cuda(outputBuffer, output, indicePairs[!inverse][i],
nHot);
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_ASSERT_INVALID_ARG(false, "unknown device type");
} }
// totalSAddTime += timer.report() / 1000.0;
} }
// std::cout << "gather time " << totalGatherTime << std::endl;
// std::cout << "gemm time " << totalGEMMTime << std::endl;
// std::cout << "scatteradd time " << totalSAddTime << std::endl;
return output; return output;
} }
std::vector<torch::Tensor> std::vector<torch::Tensor>
indiceConvBackward(torch::Tensor features, torch::Tensor filters, indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs, torch::Tensor outGrad, torch::Tensor indicePairs,
...@@ -215,7 +220,7 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -215,7 +220,7 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
auto device = features.device().type(); auto device = features.device().type();
auto ndim = filters.dim() - 2; auto ndim = filters.dim() - 2;
auto kernelVolume = indicePairs.size(0); auto kernelVolume = indiceNum.size(0);
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
...@@ -248,13 +253,13 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -248,13 +253,13 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
continue; continue;
} }
if (device == torch::kCPU) { if (device == torch::kCPU) {
sparse_gather_cpu(inputBuffer, features, indicePairs[i][inverse], nHot); sparse_gather_cpu(inputBuffer, features, indicePairs[inverse][i], nHot);
sparse_gather_cpu(outputBuffer, outGrad, indicePairs[i][!inverse], nHot); sparse_gather_cpu(outputBuffer, outGrad, indicePairs[!inverse][i], nHot);
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
sparse_gather_cuda(inputBuffer, features, indicePairs[i][inverse], nHot); sparse_gather_cuda(inputBuffer, features, indicePairs[inverse][i], nHot);
sparse_gather_cuda(outputBuffer, outGrad, indicePairs[i][!inverse], nHot); sparse_gather_cuda(outputBuffer, outGrad, indicePairs[!inverse][i], nHot);
} }
#endif #endif
else { else {
...@@ -264,17 +269,19 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -264,17 +269,19 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
auto filterGradSub = filtersGrad[i]; auto filterGradSub = filtersGrad[i];
auto outputBufferBlob = torch::from_blob(outputBuffer.data_ptr(), auto outputBufferBlob = torch::from_blob(outputBuffer.data_ptr(),
{nHot, numOutPlanes}, options); {nHot, numOutPlanes}, options);
auto inputBufferBlob = torch::from_blob(inputBuffer.data_ptr(), auto inputBufferBlob =
{nHot, numInPlanes}, options); torch::from_blob(inputBuffer.data_ptr(), {nHot, numInPlanes}, options);
torch::mm_out(filterGradSub, inputBufferBlob.t(), outputBufferBlob); torch::mm_out(filterGradSub, inputBufferBlob.t(), outputBufferBlob);
torch::mm_out(inputBufferBlob, outputBufferBlob, filters[i].t()); torch::mm_out(inputBufferBlob, outputBufferBlob, filters[i].t());
if (device == torch::kCPU) { if (device == torch::kCPU) {
sparse_scatter_add_cpu(inputBuffer, inputGrad, indicePairs[i][inverse], nHot); sparse_scatter_add_cpu(inputBuffer, inputGrad, indicePairs[inverse][i],
nHot);
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
sparse_scatter_add_cuda(inputBuffer, inputGrad, indicePairs[i][inverse], nHot); sparse_scatter_add_cuda(inputBuffer, inputGrad, indicePairs[inverse][i],
nHot);
} }
#endif #endif
else { else {
...@@ -284,5 +291,4 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -284,5 +291,4 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
return {inputGrad, filtersGrad.view(filterShape)}; return {inputGrad, filtersGrad.view(filterShape)};
} }
} // namespace spconv } // namespace spconv
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment