Commit 8cbb7d3c authored by yanyan's avatar yanyan
Browse files

1.2 release

parent 105b3892
# Changelog
## [1.2.0] - 2020-05-28
### Added
- Add batch GEMM support. It gives a small performance increase at the cost of more GPU memory usage. Pass algo=spconv.ConvAlgo.Batch to use it.
### Changed
- replace most of 'functor' with c++14 dispatch in c++ code.
### Fixed
- Change the gather/scatterAdd kernel launch parameters to support a large number of points.
...@@ -11,6 +11,7 @@ endif() ...@@ -11,6 +11,7 @@ endif()
if(WIN32) # true if windows (32 and 64 bit) if(WIN32) # true if windows (32 and 64 bit)
add_compile_definitions(TV_WINDOWS) add_compile_definitions(TV_WINDOWS)
endif() endif()
add_compile_definitions(PYTORCH_VERSION=${PYTORCH_VERSION})
set(CMAKE_CXX_EXTENSIONS OFF) # avoid gnu++11 be added to CXX flags set(CMAKE_CXX_EXTENSIONS OFF) # avoid gnu++11 be added to CXX flags
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")
......
...@@ -15,7 +15,15 @@ ...@@ -15,7 +15,15 @@
#ifndef REORDERING_CU_H_ #ifndef REORDERING_CU_H_
#define REORDERING_CU_H_ #define REORDERING_CU_H_
#include <THC/THCAtomics.cuh> #include <THC/THCAtomics.cuh>
#include <THC/THCNumerics.cuh>
#include <tensorview/kernel_utils.h> #include <tensorview/kernel_utils.h>
#if PYTORCH_VERSION < 10500
#define TH_ATOMIC_ADD atomicAdd
#else
#define TH_ATOMIC_ADD gpuAtomicAdd
#endif
// see http://www.nvidia.com/content/GTC-2010/pdfs/2238_GTC2010.pdf. // see http://www.nvidia.com/content/GTC-2010/pdfs/2238_GTC2010.pdf.
namespace spconv { namespace spconv {
...@@ -78,21 +86,21 @@ template <typename T, typename Index, int NumTLP, int NumILP, ...@@ -78,21 +86,21 @@ template <typename T, typename Index, int NumTLP, int NumILP,
__global__ void gatherVecBlockKernel(T *buffer, const T *features, __global__ void gatherVecBlockKernel(T *buffer, const T *features,
const Index *indices, int size, const Index *indices, int size,
int numPlanes) { int numPlanes) {
int ILPStrideY[NumILP]; int ILPStrideX[NumILP];
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = ilp * gridDim.y * blockDim.y; ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
features += blockIdx.x * NumTLP; features += blockIdx.y * NumTLP;
buffer += blockIdx.x * NumTLP; buffer += blockIdx.y * NumTLP;
for (int iy : tv::KernelLoopY<int, NumILP>(size)) { for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) { for (int ilp = 0; ilp < NumILP; ++ilp) {
reinterpret_cast<VecType *>( reinterpret_cast<VecType *>(
buffer)[(iy + ILPStrideY[ilp]) * numPlanes + threadIdx.x] = buffer)[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y] =
reinterpret_cast<const VecType *>( reinterpret_cast<const VecType *>(
features)[indices[iy + ILPStrideY[ilp]] * numPlanes + features)[indices[ix + ILPStrideX[ilp]] * numPlanes +
threadIdx.x]; threadIdx.y];
} }
} }
} }
...@@ -124,22 +132,33 @@ __global__ void batchGatherGenericKernel(T *buffer, const T *features, ...@@ -124,22 +132,33 @@ __global__ void batchGatherGenericKernel(T *buffer, const T *features,
for (int iy : tv::KernelLoopY<int>(numPlanes)) { for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) { for (int ilp = 0; ilp < NumILP; ++ilp) {
if (ix + ILPStrideX[ilp] < size && inds[ilp] != -1) if (ix + ILPStrideX[ilp] < size) {
buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy] = if (inds[ilp] != -1) {
features[inds[ilp] * numPlanes + iy]; buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy] =
features[inds[ilp] * numPlanes + iy];
} else {
buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy] = T(0);
}
}
} }
} }
} }
} }
template <typename T, typename Index, int NumTLP, int NumILP, typename VecType> template <typename T, typename Index, int NumTLP, int NumILP, typename VecType>
__global__ void batchGatherVecKernel(T *buffer, const T *features, __global__ void
const Index *indices, int size, batchGatherVecKernel(T *buffer, const T *features, const Index *indices,
int feature_offset, int size, int feature_offset, int numPlanes,
int numPlanes, int indice_batch_stride, int indice_batch_stride, int feature_batch_stride) {
int feature_batch_stride) {
int ILPStrideX[NumILP]; int ILPStrideX[NumILP];
Index inds[NumILP]; Index inds[NumILP];
// Zero-filled staging array that is later reinterpreted as one VecType and
// stored into `buffer` for entries whose index is -1.
// It must be declared as T (not Index): it holds sizeof(VecType) / sizeof(T)
// elements, so only with element type T does it span exactly sizeof(VecType)
// bytes. With element type Index the array is smaller than VecType whenever
// sizeof(Index) < sizeof(T), and the VecType-wide read runs past its end.
T zero[sizeof(VecType) / sizeof(T)];
#pragma unroll
for (int i = 0; i < int(sizeof(VecType) / sizeof(T)); ++i) {
zero[i] = T(0);
}
Index inds_elem; Index inds_elem;
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) for (int ilp = 0; ilp < NumILP; ilp++)
...@@ -158,11 +177,19 @@ __global__ void batchGatherVecKernel(T *buffer, const T *features, ...@@ -158,11 +177,19 @@ __global__ void batchGatherVecKernel(T *buffer, const T *features,
for (int iy : tv::KernelLoopY<int>(numPlanes)) { for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) { for (int ilp = 0; ilp < NumILP; ++ilp) {
if (ix + ILPStrideX[ilp] < size && inds[ilp] != -1) if (ix + ILPStrideX[ilp] < size) {
reinterpret_cast<VecType *>( if (inds[ilp] != -1) {
buffer)[(ix + ILPStrideX[ilp]) * numPlanes + iy] = reinterpret_cast<VecType *>(
reinterpret_cast<const VecType *>( buffer)[(ix + ILPStrideX[ilp]) * numPlanes + iy] =
features)[inds[ilp] * numPlanes + iy]; reinterpret_cast<const VecType *>(
features)[inds[ilp] * numPlanes + iy];
} else {
reinterpret_cast<VecType *>(
buffer)[(ix + ILPStrideX[ilp]) * numPlanes + iy] =
reinterpret_cast<const VecType *>(&zero)[0];
}
}
} }
} }
} }
...@@ -174,29 +201,38 @@ __global__ void ...@@ -174,29 +201,38 @@ __global__ void
batchGatherVecBlockKernel(T *buffer, const T *features, const Index *indices, batchGatherVecBlockKernel(T *buffer, const T *features, const Index *indices,
int size, int numPlanes, int indice_batch_stride, int size, int numPlanes, int indice_batch_stride,
int feature_batch_stride) { int feature_batch_stride) {
int ILPStrideY[NumILP]; int ILPStrideX[NumILP];
Index inds; Index inds;
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = ilp * gridDim.y * blockDim.y; ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
features += blockIdx.x * NumTLP; features += blockIdx.y * NumTLP;
buffer += blockIdx.x * NumTLP; buffer += blockIdx.y * NumTLP;
Index inds_elem; Index inds_elem;
// Zero-filled staging array that is later reinterpreted as one VecType and
// stored into `buffer` for entries whose index is -1.
// It must be declared as T (not Index): it holds sizeof(VecType) / sizeof(T)
// elements, so only with element type T does it span exactly sizeof(VecType)
// bytes. With element type Index the array is smaller than VecType whenever
// sizeof(Index) < sizeof(T), and the VecType-wide read runs past its end.
T zero[sizeof(VecType) / sizeof(T)];
#pragma unroll
for (int i = 0; i < int(sizeof(VecType) / sizeof(T)); ++i) {
zero[i] = T(0);
}
for (int iy : tv::KernelLoopY<int, NumILP>(size)) { for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) { for (int ilp = 0; ilp < NumILP; ++ilp) {
inds_elem = iy + ILPStrideY[ilp]; inds_elem = ix + ILPStrideX[ilp];
inds = indices[(inds_elem / feature_batch_stride) * indice_batch_stride + inds = indices[(inds_elem / feature_batch_stride) * indice_batch_stride +
inds_elem % feature_batch_stride]; inds_elem % feature_batch_stride];
if (inds != -1) { if (inds != -1) {
reinterpret_cast<VecType *>( reinterpret_cast<VecType *>(
buffer)[(iy + ILPStrideY[ilp]) * numPlanes + threadIdx.x] = buffer)[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y] =
reinterpret_cast<const VecType *>( reinterpret_cast<const VecType *>(
features)[inds * numPlanes + threadIdx.x]; features)[inds * numPlanes + threadIdx.y];
} else {
reinterpret_cast<VecType *>(
buffer)[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y] =
reinterpret_cast<const VecType *>(&zero)[0];
} }
} }
} }
...@@ -234,24 +270,24 @@ template <typename T, typename Index, int NumTLP, int NumILP, ...@@ -234,24 +270,24 @@ template <typename T, typename Index, int NumTLP, int NumILP,
__global__ void scatterAddVecBlockKernel(T *outFeatures, const T *buffer, __global__ void scatterAddVecBlockKernel(T *outFeatures, const T *buffer,
const Index *indices, int size, const Index *indices, int size,
int numPlanes) { int numPlanes) {
int ILPStrideY[NumILP]; int ILPStrideX[NumILP];
constexpr int vecloadFactor = sizeof(VecType) / sizeof(T); constexpr int vecloadFactor = sizeof(VecType) / sizeof(T);
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = ilp * gridDim.y * blockDim.y; ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
outFeatures += blockIdx.x * NumTLP; outFeatures += blockIdx.y * NumTLP;
buffer += blockIdx.x * NumTLP; buffer += blockIdx.y * NumTLP;
T buf[vecloadFactor]; T buf[vecloadFactor];
T buf2[vecloadFactor]; T buf2[vecloadFactor];
Index idx; Index idx;
for (int iy : tv::KernelLoopY<int, NumILP>(size)) { for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) { for (int ilp = 0; ilp < NumILP; ++ilp) {
idx = indices[iy + ILPStrideY[ilp]] * numPlanes + threadIdx.x; idx = indices[ix + ILPStrideX[ilp]] * numPlanes + threadIdx.y;
reinterpret_cast<VecType *>(buf)[0] = reinterpret_cast<VecType *>(buf)[0] =
reinterpret_cast<VecType *>(outFeatures)[idx]; reinterpret_cast<VecType *>(outFeatures)[idx];
reinterpret_cast<VecType *>(buf2)[0] = reinterpret_cast<const VecType *>( reinterpret_cast<VecType *>(buf2)[0] = reinterpret_cast<const VecType *>(
buffer)[(iy + ILPStrideY[ilp]) * numPlanes + threadIdx.x]; buffer)[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y];
#pragma unroll #pragma unroll
for (int i = 0; i < vecloadFactor; i++) { for (int i = 0; i < vecloadFactor; i++) {
buf[i] += buf2[i]; buf[i] += buf2[i];
...@@ -268,6 +304,10 @@ __global__ void batchScatterAddGenericKernel(T *outFeatures, const T *buffer, ...@@ -268,6 +304,10 @@ __global__ void batchScatterAddGenericKernel(T *outFeatures, const T *buffer,
int feature_offset, int numPlanes, int feature_offset, int numPlanes,
int indice_batch_stride, int indice_batch_stride,
int feature_batch_stride) { int feature_batch_stride) {
// batch scatter add is much slower than native scatter when the number of
// points is large. this may be due to atomicAdd?
// batch scatter add is much faster than native scatter when the number of
// points is small.
int ILPStrideX[NumILP]; int ILPStrideX[NumILP];
Index inds[NumILP]; Index inds[NumILP];
Index inds_elem; Index inds_elem;
...@@ -288,8 +328,8 @@ __global__ void batchScatterAddGenericKernel(T *outFeatures, const T *buffer, ...@@ -288,8 +328,8 @@ __global__ void batchScatterAddGenericKernel(T *outFeatures, const T *buffer,
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) { for (int ilp = 0; ilp < NumILP; ++ilp) {
if (ix + ILPStrideX[ilp] < size && inds[ilp] != -1) { if (ix + ILPStrideX[ilp] < size && inds[ilp] != -1) {
gpuAtomicAdd(outFeatures + inds[ilp] * numPlanes + iy, TH_ATOMIC_ADD(outFeatures + inds[ilp] * numPlanes + iy,
buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy]); buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy]);
} }
} }
} }
...@@ -301,22 +341,22 @@ __global__ void ...@@ -301,22 +341,22 @@ __global__ void
batchScatterAddBlockKernel(T *outFeatures, const T *buffer, batchScatterAddBlockKernel(T *outFeatures, const T *buffer,
const Index *indices, int size, int numPlanes, const Index *indices, int size, int numPlanes,
int indice_batch_stride, int feature_batch_stride) { int indice_batch_stride, int feature_batch_stride) {
int ILPStrideY[NumILP]; int ILPStrideX[NumILP];
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = ilp * gridDim.y * blockDim.y; ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
outFeatures += blockIdx.x * NumTLP; outFeatures += blockIdx.y * NumTLP;
buffer += blockIdx.x * NumTLP; buffer += blockIdx.y * NumTLP;
Index inds, inds_elem; Index inds, inds_elem;
for (int iy : tv::KernelLoopY<int, NumILP>(size)) { for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll #pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) { for (int ilp = 0; ilp < NumILP; ++ilp) {
inds_elem = iy + ILPStrideY[ilp]; inds_elem = ix + ILPStrideX[ilp];
inds = indices[(inds_elem / feature_batch_stride) * indice_batch_stride + inds = indices[(inds_elem / feature_batch_stride) * indice_batch_stride +
inds_elem % feature_batch_stride]; inds_elem % feature_batch_stride];
if (inds != -1) { if (inds != -1) {
gpuAtomicAdd(outFeatures + inds * numPlanes + threadIdx.x, TH_ATOMIC_ADD(outFeatures + inds * numPlanes + threadIdx.y,
buffer[(iy + ILPStrideY[ilp]) * numPlanes + threadIdx.x]); buffer[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y]);
} }
} }
} }
...@@ -324,4 +364,6 @@ batchScatterAddBlockKernel(T *outFeatures, const T *buffer, ...@@ -324,4 +364,6 @@ batchScatterAddBlockKernel(T *outFeatures, const T *buffer,
} // namespace spconv } // namespace spconv
#undef TH_ATOMIC_ADD
#endif #endif
\ No newline at end of file
...@@ -20,10 +20,10 @@ ...@@ -20,10 +20,10 @@
namespace spconv { namespace spconv {
void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
torch::Tensor indices, int size); torch::Tensor indices, int size);
void batch_sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures, void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
torch::Tensor indices, int size); torch::Tensor outFeatures,
torch::Tensor indices, int size);
void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
torch::Tensor indices, int size); torch::Tensor indices, int size);
......
...@@ -23,10 +23,7 @@ ...@@ -23,10 +23,7 @@
namespace spconv { namespace spconv {
enum ConvAlgo { enum ConvAlgo { kNative = 0, kBatch = 1, kBatchGemmGather = 2 };
kNative = 0,
kBatchGemm = 1
};
// torch.jit's doc says only support int64, so we need to convert to int32. // torch.jit's doc says only support int64, so we need to convert to int32.
template <unsigned NDim> template <unsigned NDim>
...@@ -345,8 +342,10 @@ std::vector<torch::Tensor> getIndicePairPreGrid( ...@@ -345,8 +342,10 @@ std::vector<torch::Tensor> getIndicePairPreGrid(
} }
} }
torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs, torch::Tensor indiceNum, torch::Tensor indicePairs,
int64_t numActOut, int64_t _inverse, int64_t _subM); torch::Tensor indiceNum, int64_t numActOut,
int64_t _inverse, int64_t _subM,
bool batchScatter);
torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs, torch::Tensor indiceNum, torch::Tensor indicePairs, torch::Tensor indiceNum,
...@@ -355,13 +354,14 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -355,13 +354,14 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
std::vector<torch::Tensor> std::vector<torch::Tensor>
indiceConvBackward(torch::Tensor features, torch::Tensor filters, indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs, torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse, int64_t _subM, int64_t algo); torch::Tensor indiceNum, int64_t _inverse, int64_t _subM,
int64_t algo);
std::vector<torch::Tensor> std::vector<torch::Tensor>
indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters, indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs, torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM); int64_t _subM, bool batchScatter);
} // namespace spconv } // namespace spconv
#endif #endif
\ No newline at end of file
...@@ -80,6 +80,8 @@ public: ...@@ -80,6 +80,8 @@ public:
} }
} else { } else {
#ifdef TV_CUDA #ifdef TV_CUDA
// the device should be selected externally (by the caller), so the
// device check/selection below is disabled.
/*
int deviceCount; int deviceCount;
cudaGetDeviceCount(&deviceCount); cudaGetDeviceCount(&deviceCount);
if (device >= deviceCount) { if (device >= deviceCount) {
...@@ -87,6 +89,7 @@ public: ...@@ -87,6 +89,7 @@ public:
" but you only have ", deviceCount, " device."); " but you only have ", deviceCount, " device.");
} }
cudaSetDevice(device); cudaSetDevice(device);
*/
if (managed) { if (managed) {
checkCudaErrors(cudaMallocManaged(&this->mPtr, size * sizeof(T))); checkCudaErrors(cudaMallocManaged(&this->mPtr, size * sizeof(T)));
} else { } else {
......
...@@ -125,6 +125,21 @@ TensorView<T, Rank, PtrTraits, Tindex> torch2tv(const torch::Tensor &tensor) { ...@@ -125,6 +125,21 @@ TensorView<T, Rank, PtrTraits, Tindex> torch2tv(const torch::Tensor &tensor) {
return tv::TensorView<T, Rank, PtrTraits, Tindex>( return tv::TensorView<T, Rank, PtrTraits, Tindex>(
tensor.data_ptr<std::remove_const_t<T>>(), shape); tensor.data_ptr<std::remove_const_t<T>>(), shape);
} }
/// Returns a view of `tensor` restricted to rows [start, end) of its first
/// axis, without copying any data.
///
/// Hand-written substitute for tensor slicing, which (per the original note)
/// is only available in torch >= 1.5.
///
/// @param tensor source tensor; must have at least one dimension.
/// @param start  first row to include.
/// @param end    one past the last row to include.
/// @return a tensor built with torch::from_blob over `tensor`'s memory. It
///         aliases that memory but does not own it, so the caller must keep
///         `tensor` alive for as long as the result is used.
///
/// No bounds checking is performed on `start` / `end`.
template <typename T>
torch::Tensor torch_slice_first_axis(torch::Tensor tensor, T start, T end) {
  auto tensor_shape = tensor.sizes();
  std::vector<int64_t> shape(tensor_shape.begin(), tensor_shape.end());
  shape[0] = end - start;
  uint8_t *ptr = reinterpret_cast<uint8_t *>(tensor.data_ptr());
  // stride(0) is counted in elements and itemsize() in bytes, so their
  // product is the byte distance between consecutive rows.
  uint8_t *first_row = ptr + start * tensor.stride(0) * tensor.itemsize();
  // Forward the source strides so the view is also correct for a
  // non-contiguous `tensor`; for a contiguous one this matches the default
  // layout from_blob would otherwise assume.
  return torch::from_blob(first_row, torch::IntArrayRef(shape),
                          tensor.strides(), tensor.options());
}
namespace detail { namespace detail {
template <> struct TypeToString<at::Half> { template <> struct TypeToString<at::Half> {
static constexpr const char *value = "half"; static constexpr const char *value = "half";
......
...@@ -18,6 +18,8 @@ LIBTORCH_ROOT = str(Path(torch.__file__).parent) ...@@ -18,6 +18,8 @@ LIBTORCH_ROOT = str(Path(torch.__file__).parent)
SPCONV_FORCE_BUILD_CUDA = os.getenv("SPCONV_FORCE_BUILD_CUDA") SPCONV_FORCE_BUILD_CUDA = os.getenv("SPCONV_FORCE_BUILD_CUDA")
PYTHON_VERSION = "{}.{}".format(sys.version_info.major, sys.version_info.minor) PYTHON_VERSION = "{}.{}".format(sys.version_info.major, sys.version_info.minor)
def _leading_int(text):
    """Return the integer formed by the leading digits of ``text`` (0 if none)."""
    digits = ""
    for ch in text:
        if not ch.isdigit():
            break
        digits += ch
    return int(digits) if digits else 0


# torch.__version__ may carry a local build tag ("1.5.0+cu101") or a
# pre-release suffix ("1.6.0a0"); calling int() on the raw dot-separated
# fields raises ValueError for those, so keep only the leading digits of the
# first three release components.
PYTORCH_VERSION = [
    _leading_int(part)
    for part in torch.__version__.split("+")[0].split(".")[:3]
]
PYTORCH_VERSION += [0] * (3 - len(PYTORCH_VERSION))  # pad e.g. "1.5" to [1, 5, 0]
# Encoded as major * 10000 + minor * 100 + patch, e.g. 1.5.0 -> 10500.
PYTORCH_VERSION_NUMBER = PYTORCH_VERSION[0] * 10000 + PYTORCH_VERSION[1] * 100 + PYTORCH_VERSION[2]
class CMakeExtension(Extension): class CMakeExtension(Extension):
def __init__(self, name, sourcedir='', library_dirs=[]): def __init__(self, name, sourcedir='', library_dirs=[]):
...@@ -47,6 +49,7 @@ class CMakeBuild(build_ext): ...@@ -47,6 +49,7 @@ class CMakeBuild(build_ext):
'-DCMAKE_PREFIX_PATH={}'.format(LIBTORCH_ROOT), '-DCMAKE_PREFIX_PATH={}'.format(LIBTORCH_ROOT),
'-DPYBIND11_PYTHON_VERSION={}'.format(PYTHON_VERSION), '-DPYBIND11_PYTHON_VERSION={}'.format(PYTHON_VERSION),
'-DSPCONV_BuildTests=OFF', '-DSPCONV_BuildTests=OFF',
'-DPYTORCH_VERSION={}'.format(PYTORCH_VERSION_NUMBER)
] # -arch=sm_61 ] # -arch=sm_61
if not torch.cuda.is_available() and SPCONV_FORCE_BUILD_CUDA is None: if not torch.cuda.is_available() and SPCONV_FORCE_BUILD_CUDA is None:
cmake_args += ['-DSPCONV_BuildCUDA=OFF'] cmake_args += ['-DSPCONV_BuildCUDA=OFF']
......
...@@ -19,12 +19,12 @@ import numpy as np ...@@ -19,12 +19,12 @@ import numpy as np
import torch import torch
from spconv import ops, utils from spconv import ops, utils
from spconv.ops import ConvAlgo
from spconv.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d, from spconv.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
SparseConvTranspose3d, SparseInverseConv2d, SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SubMConv2d, SubMConv3d) SparseInverseConv3d, SubMConv2d, SubMConv3d)
from spconv.identity import Identity from spconv.identity import Identity
from spconv.modules import SparseModule, SparseSequential from spconv.modules import SparseModule, SparseSequential
from spconv.ops import ConvAlgo
from spconv.pool import SparseMaxPool2d, SparseMaxPool3d from spconv.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.tables import AddTable, ConcatTable, JoinTable from spconv.tables import AddTable, ConcatTable, JoinTable
...@@ -62,7 +62,7 @@ class SparseConvTensor(object): ...@@ -62,7 +62,7 @@ class SparseConvTensor(object):
self.features = features self.features = features
self.indices = indices self.indices = indices
if self.indices.dtype != torch.int32: if self.indices.dtype != torch.int32:
self.indices.int() self.indices = self.indices.int()
self.spatial_shape = spatial_shape self.spatial_shape = spatial_shape
self.batch_size = batch_size self.batch_size = batch_size
self.indice_dict = {} self.indice_dict = {}
...@@ -82,7 +82,8 @@ class SparseConvTensor(object): ...@@ -82,7 +82,8 @@ class SparseConvTensor(object):
def dense(self, channels_first=True): def dense(self, channels_first=True):
output_shape = [self.batch_size] + list( output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]] self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(self.indices.long().to(self.features.device), self.features, output_shape) res = scatter_nd(self.indices.long().to(self.features.device),
self.features, output_shape)
if not channels_first: if not channels_first:
return res return res
ndim = len(self.spatial_shape) ndim = len(self.spatial_shape)
......
...@@ -25,16 +25,25 @@ class SparseConvFunction(Function): ...@@ -25,16 +25,25 @@ class SparseConvFunction(Function):
num_activate_out, algo): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs, return ops.indice_conv(features,
indice_pair_num, num_activate_out, False, algo=algo) filters,
indice_pairs,
indice_pair_num,
num_activate_out,
False,
algo=algo)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward( input_bp, filters_bp = ops.indice_conv_backward(features,
features, filters, grad_output, indice_pairs, indice_pair_num, filters,
False, algo=ctx.algo) grad_output,
indice_pairs,
indice_pair_num,
False,
algo=ctx.algo)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None
...@@ -45,15 +54,26 @@ class SparseInverseConvFunction(Function): ...@@ -45,15 +54,26 @@ class SparseInverseConvFunction(Function):
num_activate_out, algo): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs, return ops.indice_conv(features,
indice_pair_num, num_activate_out, True, False, algo=algo) filters,
indice_pairs,
indice_pair_num,
num_activate_out,
True,
False,
algo=algo)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward( input_bp, filters_bp = ops.indice_conv_backward(features,
features, filters, grad_output, indice_pairs, indice_pair_num, filters,
True, False, algo=ctx.algo) grad_output,
indice_pairs,
indice_pair_num,
True,
False,
algo=ctx.algo)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None
...@@ -64,15 +84,26 @@ class SubMConvFunction(Function): ...@@ -64,15 +84,26 @@ class SubMConvFunction(Function):
num_activate_out, algo): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs, return ops.indice_conv(features,
indice_pair_num, num_activate_out, False, True, algo=algo) filters,
indice_pairs,
indice_pair_num,
num_activate_out,
False,
True,
algo=algo)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward( input_bp, filters_bp = ops.indice_conv_backward(features,
features, filters, grad_output, indice_pairs, indice_pair_num, filters,
False, True, algo=ctx.algo) grad_output,
indice_pairs,
indice_pair_num,
False,
True,
algo=ctx.algo)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import time import time
from collections import OrderedDict from collections import OrderedDict
......
...@@ -12,15 +12,18 @@ ...@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum
import torch import torch
import spconv import spconv
from enum import Enum
class ConvAlgo(Enum): class ConvAlgo(Enum):
Native = 0 Native = 0 # small memory cost, faster when number of points is large.
BatchGemm = 1 Batch = 1 # high memory cost, faster when number of points is small (< 50000)
BatchGemmGather = 2 # high memory cost, faster when number of points medium
def get_conv_output_size(input_size, kernel_size, stride, padding, dilation): def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
ndim = len(input_size) ndim = len(input_size)
...@@ -59,7 +62,7 @@ def get_indice_pairs(indices, ...@@ -59,7 +62,7 @@ def get_indice_pairs(indices,
subm=False, subm=False,
transpose=False, transpose=False,
grid=None, grid=None,
use_hash=True): use_hash=False):
ndim = indices.shape[1] - 1 ndim = indices.shape[1] - 1
if not isinstance(ksize, (list, tuple)): if not isinstance(ksize, (list, tuple)):
ksize = [ksize] * ndim ksize = [ksize] * ndim
...@@ -133,7 +136,7 @@ def indice_conv_backward(features, ...@@ -133,7 +136,7 @@ def indice_conv_backward(features,
indice_pair_num, indice_pair_num,
inverse=False, inverse=False,
subm=False, subm=False,
algo=ConvAlgo.Native.value): algo=ConvAlgo.Native.value):
return torch.ops.spconv.indice_conv_backward(features, filters, out_bp, return torch.ops.spconv.indice_conv_backward(features, filters, out_bp,
indice_pairs, indice_pair_num, indice_pairs, indice_pair_num,
int(inverse), int(subm), algo) int(inverse), int(subm), algo)
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include <tensorview/torch_utils.h> #include <tensorview/torch_utils.h>
#include <type_traits> #include <type_traits>
#include <utility/timer.h> #include <utility/timer.h>
namespace spconv { namespace spconv {
using float_types_t = tv::mp_list<float, double, at::Half>; using float_types_t = tv::mp_list<float, double, at::Half>;
...@@ -48,7 +47,7 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -48,7 +47,7 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
using Index = decltype(IndexValue); using Index = decltype(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>( tv::mp_for_each<kernel_block_t>(
[=, &buffer, &features, &indices, &notFound](auto NumTLP) { [=, &buffer, &features, &indices, &notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
...@@ -59,8 +58,8 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -59,8 +58,8 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
if (nHotBlock >= NumTLP) { if (nHotBlock >= NumTLP) {
gatherVecBlockKernel<T, Index, int(NumTLP), NumILP, gatherVecBlockKernel<T, Index, int(NumTLP), NumILP,
vecload_type_t> vecload_type_t>
<<<dim3(numPlanes / NumTLP, size / NumTLP), <<<dim3(size / NumTLP, numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0,
stream>>>(buffer.data_ptr<T>(), features.data_ptr<T>(), stream>>>(buffer.data_ptr<T>(), features.data_ptr<T>(),
indices.data_ptr<Index>(), nHotBlock, indices.data_ptr<Index>(), nHotBlock,
numPlanes / vecloadFactor); numPlanes / vecloadFactor);
...@@ -115,7 +114,7 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures, ...@@ -115,7 +114,7 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = constexpr int vecloadFactor =
sizeof(vecload_type_t) / sizeof(T); // important for half. sizeof(vecload_type_t) / sizeof(T); // important for half.
tv::mp_for_each<kernel_block_t>([=, &outFeatures, &buffer, &indices, tv::mp_for_each<kernel_block_t>([=, &outFeatures, &buffer, &indices,
&notFound](auto NumTLP) { &notFound](auto NumTLP) {
// constexpr int NumILP = NumTLP / (64 / (NumTLP / // constexpr int NumILP = NumTLP / (64 / (NumTLP /
...@@ -127,8 +126,8 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures, ...@@ -127,8 +126,8 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
if (nHotBlock >= NumTLP) { if (nHotBlock >= NumTLP) {
scatterAddVecBlockKernel<T, Index, int(NumTLP), NumILP, scatterAddVecBlockKernel<T, Index, int(NumTLP), NumILP,
vecload_type_t> vecload_type_t>
<<<dim3(numPlanes / NumTLP, size / NumTLP), <<<dim3(size / NumTLP, numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0,
stream>>>(outFeatures.data_ptr<T>(), buffer.data_ptr<T>(), stream>>>(outFeatures.data_ptr<T>(), buffer.data_ptr<T>(),
indices.data_ptr<Index>(), nHotBlock, indices.data_ptr<Index>(), nHotBlock,
numPlanes / vecloadFactor); numPlanes / vecloadFactor);
...@@ -194,31 +193,31 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -194,31 +193,31 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
if (nHotBlock >= NumTLP) { if (nHotBlock >= NumTLP) {
batchGatherVecBlockKernel<T, Index, int(NumTLP), NumILP, batchGatherVecBlockKernel<T, Index, int(NumTLP), NumILP,
vecload_type_t> vecload_type_t>
<<<dim3(numPlanes / NumTLP, size / NumTLP), <<<dim3(size / NumTLP, numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0,
stream>>>(buffer.data_ptr<T>(), features.data_ptr<T>(), stream>>>(buffer.data_ptr<T>(), features.data_ptr<T>(),
indices.data_ptr<Index>(), nHotBlock, indices.data_ptr<Index>(), nHotBlock,
numPlanes / vecloadFactor, inds_stride, numPlanes / vecloadFactor, inds_stride,
feature_stride); feature_stride);
TV_CHECK_CUDA_ERR_V2("batchGatherVecBlockKernel");
TV_CHECK_CUDA_ERR();
} }
if (size - nHotBlock > 0) { if (size - nHotBlock > 0) {
batchGatherVecKernel<T, Index, int(NumTLP), NumILP, vecload_type_t> batchGatherVecKernel<T, Index, int(NumTLP), NumILP,
vecload_type_t>
<<<dim3(1, numPlanes / NumTLP), <<<dim3(1, numPlanes / NumTLP),
dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0, dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0,
stream>>>(buffer.data_ptr<T>() + nHotBlock * numPlanes, stream>>>(buffer.data_ptr<T>() + nHotBlock * numPlanes,
features.data_ptr<T>(), features.data_ptr<T>(),
indices.data_ptr<Index>(), indices.data_ptr<Index>(), size - nHotBlock,
size - nHotBlock, nHotBlock, numPlanes / vecloadFactor, nHotBlock, numPlanes / vecloadFactor,
inds_stride, feature_stride); inds_stride, feature_stride);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR_V2("batchGatherVecKernel");
} }
notFound = false; notFound = false;
} }
} }
}); });
if (notFound) { if (notFound) {
constexpr int NumTLP = 64; constexpr int NumTLP = 64;
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
...@@ -259,7 +258,7 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer, ...@@ -259,7 +258,7 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
using Index = decltype(IndexValue); using Index = decltype(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = 1; // important for half. constexpr int vecloadFactor = 1; // important for half.
tv::mp_for_each<kernel_block_t>([=, &outFeatures, &buffer, &indices, tv::mp_for_each<kernel_block_t>([=, &outFeatures, &buffer, &indices,
&notFound](auto NumTLP) { &notFound](auto NumTLP) {
// constexpr int NumILP = NumTLP / (64 / (NumTLP / // constexpr int NumILP = NumTLP / (64 / (NumTLP /
...@@ -270,12 +269,12 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer, ...@@ -270,12 +269,12 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
if (numPlanes % NumTLP == 0) { if (numPlanes % NumTLP == 0) {
if (nHotBlock >= NumTLP) { if (nHotBlock >= NumTLP) {
batchScatterAddBlockKernel<T, Index, int(NumTLP), NumILP> batchScatterAddBlockKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(numPlanes / NumTLP, size / NumTLP), <<<dim3(size / NumTLP, numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0,
stream>>>(outFeatures.data_ptr<T>(), buffer.data_ptr<T>(), stream>>>(outFeatures.data_ptr<T>(), buffer.data_ptr<T>(),
indices.data_ptr<Index>(), nHotBlock, indices.data_ptr<Index>(), nHotBlock,
numPlanes / vecloadFactor, inds_stride, numPlanes / vecloadFactor, inds_stride,
feature_stride); feature_stride);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
if (size - nHotBlock > 0) { if (size - nHotBlock > 0) {
...@@ -283,8 +282,8 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer, ...@@ -283,8 +282,8 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP), <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, stream>>>(outFeatures.data_ptr<T>(), 0, stream>>>(outFeatures.data_ptr<T>(),
buffer.data_ptr<T>() + nHotBlock * numPlanes, buffer.data_ptr<T>() + nHotBlock * numPlanes,
indices.data_ptr<Index>(), indices.data_ptr<Index>(), size - nHotBlock,
size - nHotBlock, nHotBlock, numPlanes, inds_stride, nHotBlock, numPlanes, inds_stride,
feature_stride); feature_stride);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
...@@ -292,7 +291,7 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer, ...@@ -292,7 +291,7 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
} }
} }
}); });
if (notFound) { if (notFound) {
constexpr int NumTLP = 64; constexpr int NumTLP = 64;
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
...@@ -309,4 +308,4 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer, ...@@ -309,4 +308,4 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
}); });
} }
} // namespace spconv } // namespace spconv
\ No newline at end of file
...@@ -139,10 +139,12 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -139,10 +139,12 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
int64_t algo) { int64_t algo) {
auto kernelVolume = indiceNum.size(0); auto kernelVolume = indiceNum.size(0);
switch (algo) { switch (algo) {
case kBatchGemm: { case kBatchGemmGather:
case kBatch: {
if (kernelVolume != 1) { if (kernelVolume != 1) {
return indiceConvBatch(features, filters, indicePairs, indiceNum, return indiceConvBatch(features, filters, indicePairs, indiceNum,
numActOut, _inverse, _subM); numActOut, _inverse, _subM,
algo != kBatchGemmGather);
} else { } else {
break; break;
} }
...@@ -152,6 +154,8 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -152,6 +154,8 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
default: default:
TV_THROW_RT_ERR("unknown algo"); TV_THROW_RT_ERR("unknown algo");
} }
// auto timer = spconv::CudaContextTimer<>();
bool subM = _subM != 0; bool subM = _subM != 0;
bool inverse = _inverse != 0; bool inverse = _inverse != 0;
auto device = features.device().type(); auto device = features.device().type();
...@@ -170,10 +174,11 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -170,10 +174,11 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::TensorOptions().dtype(features.dtype()).device(features.device()); torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options); torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options);
torch::Tensor inputBuffer = torch::Tensor inputBuffer =
torch::zeros({indicePairMaxSize, numInPlanes}, options); torch::empty({indicePairMaxSize, numInPlanes}, options);
torch::Tensor outputBuffer = torch::Tensor outputBuffer =
torch::empty({indicePairMaxSize, numOutPlanes}, options); torch::empty({indicePairMaxSize, numOutPlanes}, options);
filters = filters.view({-1, numInPlanes, numOutPlanes}); filters = filters.view({-1, numInPlanes, numOutPlanes});
if (subM) { // the center index of subm conv don't need gather and scatter if (subM) { // the center index of subm conv don't need gather and scatter
// add. // add.
torch::mm_out(output, features, filters[indicePairMaxOffset]); torch::mm_out(output, features, filters[indicePairMaxOffset]);
...@@ -181,12 +186,13 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -181,12 +186,13 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
double totalGatherTime = 0; double totalGatherTime = 0;
double totalGEMMTime = 0; double totalGEMMTime = 0;
double totalSAddTime = 0; double totalSAddTime = 0;
// tv::ssprint("first subm gemm time", timer.report() / 1000.0);
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0 || (subM && i == indicePairMaxOffset)) { if (nHot <= 0 || (subM && i == indicePairMaxOffset)) {
continue; continue;
} }
// auto timer = spconv::CudaContextTimer<>();
auto outputBufferBlob = torch::from_blob(outputBuffer.data_ptr(), auto outputBufferBlob = torch::from_blob(outputBuffer.data_ptr(),
{nHot, numOutPlanes}, options); {nHot, numOutPlanes}, options);
auto inputBufferBlob = auto inputBufferBlob =
...@@ -208,7 +214,10 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -208,7 +214,10 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
else { else {
TV_THROW_INVALID_ARG("unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
// totalGatherTime += timer.report() / 1000.0;
torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]); torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]);
// totalGEMMTime += timer.report() / 1000.0;
if (device == torch::kCPU) { if (device == torch::kCPU) {
sparse_scatter_add_cpu(outputBuffer, output, indicePairs[!inverse][i], sparse_scatter_add_cpu(outputBuffer, output, indicePairs[!inverse][i],
nHot); nHot);
...@@ -222,14 +231,17 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -222,14 +231,17 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
else { else {
TV_THROW_INVALID_ARG("unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
// totalSAddTime += timer.report() / 1000.0;
} }
// tv::ssprint(totalGatherTime, totalGEMMTime, totalSAddTime);
return output; return output;
} }
torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numActOut, torch::Tensor indiceNum, int64_t numActOut,
int64_t _inverse, int64_t _subM) { int64_t _inverse, int64_t _subM,
bool batchScatter) {
bool subM = _subM != 0; bool subM = _subM != 0;
bool inverse = _inverse != 0; bool inverse = _inverse != 0;
auto device = features.device().type(); auto device = features.device().type();
...@@ -238,6 +250,7 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, ...@@ -238,6 +250,7 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
TV_ASSERT_INVALID_ARG(kernelVolume > 1, "error"); TV_ASSERT_INVALID_ARG(kernelVolume > 1, "error");
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
// auto timer = spconv::CudaContextTimer<>();
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairNumVec = auto indicePairNumVec =
std::vector<int>(indicePairNumCpu.data_ptr<int>(), std::vector<int>(indicePairNumCpu.data_ptr<int>(),
...@@ -257,85 +270,98 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, ...@@ -257,85 +270,98 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
// number of indice in the center of filter is much more than other // number of indice in the center of filter is much more than other
// filter location. // filter location.
// so we first use top2 indice num to do batch conv, then // so we first use top2 indice num to do batch conv, then
// do native conv in center. // do native conv (gemm) in center.
int bufferSize = subM ? indicePairTop2Size : indicePairMaxSize; int bufferSize = subM ? indicePairTop2Size : indicePairMaxSize;
torch::Tensor inputBuffer = int maxKernelVolumePart = kernelVolume;
torch::zeros({kernelVolume, bufferSize, numInPlanes}, options); std::vector<std::pair<int, int>> part_ranges = {{0, kernelVolume}};
torch::Tensor outputBuffer =
torch::empty({kernelVolume, bufferSize, numOutPlanes}, options);
filters = filters.view({kernelVolume, numInPlanes, numOutPlanes}); filters = filters.view({kernelVolume, numInPlanes, numOutPlanes});
int64_t size = kernelVolume * bufferSize;
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
batch_sparse_gather_cuda(inputBuffer, features, indicePairs[inverse], size);
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
torch::bmm_out(outputBuffer, inputBuffer, filters);
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
batch_sparse_scatter_add_cuda(outputBuffer, output, indicePairs[!inverse],
size);
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
if (subM) { if (subM) {
auto remain_size = indicePairMaxSize - indicePairTop2Size; maxKernelVolumePart = std::max(indicePairMaxOffset,
if (remain_size <= 0) { int(kernelVolume - indicePairMaxOffset - 1));
part_ranges = {{0, indicePairMaxOffset},
{indicePairMaxOffset + 1, kernelVolume}};
torch::mm_out(output, features, filters[indicePairMaxOffset]);
if (indicePairTop2Size == 0) {
return output; return output;
} }
inputBuffer = torch::empty({remain_size, numInPlanes}, options); }
outputBuffer = torch::empty({remain_size, numOutPlanes}, options); // tv::ssprint("first subm gemm time", timer.report() / 1000.0);
double totalGatherTime = 0;
double totalGEMMTime = 0;
double totalSAddTime = 0;
torch::Tensor inputBuffer =
torch::empty({maxKernelVolumePart, bufferSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::empty({maxKernelVolumePart, bufferSize, numOutPlanes}, options);
for (auto &range : part_ranges) {
int start = range.first;
int end = range.second;
int length = end - start;
int64_t size = length * bufferSize;
auto inputBufferPart = tv::torch_slice_first_axis(inputBuffer, 0, length);
auto outputBufferPart = tv::torch_slice_first_axis(outputBuffer, 0, length);
auto indicePairs1Part =
tv::torch_slice_first_axis(indicePairs[inverse], start, end);
auto indicePairs2Part =
tv::torch_slice_first_axis(indicePairs[!inverse], start, end);
auto filtersPart = tv::torch_slice_first_axis(filters, start, end);
if (device == torch::kCPU) { if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) { batch_sparse_gather_cuda(inputBufferPart, features, indicePairs1Part,
using Index = decltype(I); size);
auto indicePairsRemain = torch::from_blob(
indicePairs[inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
sparse_gather_cuda(inputBuffer, features, indicePairsRemain,
remain_size);
});
} }
#endif #endif
else { else {
TV_THROW_INVALID_ARG("unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
torch::mm_out(outputBuffer, inputBuffer, filters[indicePairMaxOffset]); // totalGatherTime += timer.report() / 1000.0;
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type"); torch::bmm_out(outputBufferPart, inputBufferPart, filtersPart);
} // totalGEMMTime += timer.report() / 1000.0;
if (batchScatter) {
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) { batch_sparse_scatter_add_cuda(outputBufferPart, output,
using Index = decltype(I); indicePairs2Part, size);
auto indicePairsRemain = torch::from_blob( }
indicePairs[!inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
sparse_scatter_add_cuda(outputBuffer, output, indicePairsRemain,
remain_size);
});
}
#endif #endif
else { else {
TV_THROW_INVALID_ARG("unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
}
} else {
for (int i = 0; i < length; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i + start];
if (nHot <= 0) {
continue;
}
if (device == torch::kCPU) {
sparse_scatter_add_cpu(outputBufferPart[i], output,
indicePairs2Part[i], nHot);
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
sparse_scatter_add_cuda(outputBufferPart[i], output,
indicePairs2Part[i], nHot);
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
}
} }
// totalSAddTime += timer.report() / 1000.0;
} }
// tv::ssprint(totalGatherTime, totalGEMMTime, totalSAddTime);
return output; return output;
} }
...@@ -346,10 +372,12 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -346,10 +372,12 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
int64_t algo) { int64_t algo) {
auto kernelVolume = indiceNum.size(0); auto kernelVolume = indiceNum.size(0);
switch (algo) { switch (algo) {
case kBatchGemm: { case kBatchGemmGather:
case kBatch: {
if (kernelVolume != 1) { if (kernelVolume != 1) {
return indiceConvBackwardBatch(features, filters, outGrad, indicePairs, return indiceConvBackwardBatch(features, filters, outGrad, indicePairs,
indiceNum, _inverse, _subM); indiceNum, _inverse, _subM,
algo != kBatchGemmGather);
} else { } else {
break; break;
} }
...@@ -439,7 +467,7 @@ std::vector<torch::Tensor> ...@@ -439,7 +467,7 @@ std::vector<torch::Tensor>
indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters, indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs, torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM) { int64_t _subM, bool batchScatter) {
bool subM = _subM != 0; bool subM = _subM != 0;
bool inverse = _inverse != 0; bool inverse = _inverse != 0;
...@@ -467,101 +495,99 @@ indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters, ...@@ -467,101 +495,99 @@ indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor inputGrad = torch::zeros(features.sizes(), options); torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
torch::Tensor filtersGrad = torch::zeros(filterShape, options); torch::Tensor filtersGrad = torch::zeros(filterShape, options);
int bufferSize = subM ? indicePairTop2Size : indicePairMaxSize; int bufferSize = subM ? indicePairTop2Size : indicePairMaxSize;
torch::Tensor inputBuffer =
torch::zeros({kernelVolume, bufferSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::zeros({kernelVolume, bufferSize, numOutPlanes}, options);
filters = filters.view({-1, numInPlanes, numOutPlanes}); filters = filters.view({-1, numInPlanes, numOutPlanes});
filtersGrad = filtersGrad.view({-1, numInPlanes, numOutPlanes}); filtersGrad = filtersGrad.view({-1, numInPlanes, numOutPlanes});
int64_t size = kernelVolume * bufferSize;
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
batch_sparse_gather_cuda(inputBuffer, features, indicePairs[inverse], size);
batch_sparse_gather_cuda(outputBuffer, outGrad, indicePairs[!inverse],
size);
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
// filters: KV, I, O, inputBuffer: [KV, buffer, I]
// outputBuffer: [KV, buffer, O]
torch::bmm_out(filtersGrad, inputBuffer.permute({0, 2, 1}), outputBuffer);
torch::bmm_out(inputBuffer, outputBuffer, filters.permute({0, 2, 1}));
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
batch_sparse_scatter_add_cuda(inputBuffer, inputGrad, indicePairs[inverse],
size);
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
std::vector<std::pair<int, int>> part_ranges = {{0, kernelVolume}};
int maxKernelVolumePart = kernelVolume;
if (subM) { if (subM) {
auto remain_size = indicePairMaxSize - indicePairTop2Size; maxKernelVolumePart = std::max(indicePairMaxOffset,
if (remain_size <= 0) { int(kernelVolume - indicePairMaxOffset - 1));
part_ranges = {{0, indicePairMaxOffset},
{indicePairMaxOffset + 1, kernelVolume}};
auto filtersGradSub = filtersGrad[indicePairMaxOffset];
auto filtersSub = filters[indicePairMaxOffset];
torch::mm_out(filtersGradSub, features.t(), outGrad);
torch::mm_out(inputGrad, outGrad, filtersSub.t());
if (indicePairTop2Size == 0) {
return {inputGrad, filtersGrad.view(filterShape)}; return {inputGrad, filtersGrad.view(filterShape)};
} }
inputBuffer = torch::zeros({remain_size, numInPlanes}, options); }
outputBuffer = torch::zeros({remain_size, numOutPlanes}, options); torch::Tensor inputBuffer =
torch::zeros({maxKernelVolumePart, bufferSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::zeros({maxKernelVolumePart, bufferSize, numOutPlanes}, options);
for (auto &range : part_ranges) {
int start = range.first;
int end = range.second;
int length = end - start;
int64_t size = length * bufferSize;
auto inputBufferPart = tv::torch_slice_first_axis(inputBuffer, 0, length);
auto outputBufferPart = tv::torch_slice_first_axis(outputBuffer, 0, length);
auto indicePairs1Part =
tv::torch_slice_first_axis(indicePairs[inverse], start, end);
auto indicePairs2Part =
tv::torch_slice_first_axis(indicePairs[!inverse], start, end);
auto filtersPart = tv::torch_slice_first_axis(filters, start, end);
auto filtersGradPart = tv::torch_slice_first_axis(filtersGrad, start, end);
if (device == torch::kCPU) { if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) { batch_sparse_gather_cuda(inputBufferPart, features, indicePairs1Part,
using Index = decltype(I); size);
auto indicePairsRemain = torch::from_blob( batch_sparse_gather_cuda(outputBufferPart, outGrad, indicePairs2Part,
indicePairs[inverse][indicePairMaxOffset].data_ptr<Index>() + size);
indicePairTop2Size,
{remain_size}, indicePairs.options());
auto indicePairsRemain2 = torch::from_blob(
indicePairs[!inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
batch_sparse_gather_cuda(inputBuffer, features, indicePairsRemain,
remain_size);
batch_sparse_gather_cuda(outputBuffer, outGrad, indicePairsRemain2,
remain_size);
});
} }
#endif #endif
else { else {
TV_THROW_INVALID_ARG("unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
torch::mm_out(filtersGrad, inputBuffer.t(), outputBuffer); // filters: KV, I, O, inputBuffer: [KV, buffer, I]
torch::mm_out(inputBuffer, outputBuffer, filters[indicePairMaxOffset].t()); // outputBuffer: [KV, buffer, O]
if (device == torch::kCPU) { torch::bmm_out(filtersGradPart, inputBufferPart.permute({0, 2, 1}),
TV_THROW_INVALID_ARG("unknown device type"); outputBufferPart);
} torch::bmm_out(inputBuffer, outputBufferPart,
filtersPart.permute({0, 2, 1}));
if (batchScatter) {
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) { batch_sparse_scatter_add_cuda(inputBufferPart, inputGrad,
using Index = decltype(I); indicePairs1Part, size);
auto indicePairsRemain2 = torch::from_blob( }
indicePairs[!inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
batch_sparse_scatter_add_cuda(inputBuffer, inputGrad,
indicePairsRemain2, remain_size);
});
}
#endif #endif
else { else {
TV_THROW_INVALID_ARG("unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
}
} else {
for (int i = 0; i < length; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i + start];
if (nHot <= 0) {
continue;
}
if (device == torch::kCPU) {
sparse_scatter_add_cpu(inputBufferPart[i], inputGrad,
indicePairs1Part[i], nHot);
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
sparse_scatter_add_cuda(inputBufferPart[i], inputGrad,
indicePairs1Part[i], nHot);
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
}
} }
} }
return {inputGrad, filtersGrad.view(filterShape)}; return {inputGrad, filtersGrad.view(filterShape)};
} }
......
...@@ -27,11 +27,18 @@ from spconv.test_utils import TestCase, generate_sparse_data, params_grid ...@@ -27,11 +27,18 @@ from spconv.test_utils import TestCase, generate_sparse_data, params_grid
class SparseConv3dTestTorch(nn.Module): class SparseConv3dTestTorch(nn.Module):
def __init__(self, num_layers, ndim, shape, in_channels, out_channels, def __init__(self,
kernel_size, stride, padding, dilation): num_layers,
ndim,
shape,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
algo=spconv.ConvAlgo.BatchGemmGather):
super().__init__() super().__init__()
algo = spconv.ConvAlgo.BatchGemm
layers = [ layers = [
spconv.SparseConv3d(in_channels, spconv.SparseConv3d(in_channels,
out_channels, out_channels,
...@@ -67,8 +74,17 @@ class SparseConv3dTestTorch(nn.Module): ...@@ -67,8 +74,17 @@ class SparseConv3dTestTorch(nn.Module):
class SubMConv3dTestTorch(nn.Module): class SubMConv3dTestTorch(nn.Module):
def __init__(self, num_layers, ndim, shape, in_channels, out_channels, def __init__(self,
kernel_size, stride, padding, dilation, algo=spconv.ConvAlgo.Native): num_layers,
ndim,
shape,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
algo=spconv.ConvAlgo.Native):
super().__init__() super().__init__()
layers = [ layers = [
spconv.SubMConv3d(in_channels, spconv.SubMConv3d(in_channels,
...@@ -89,14 +105,14 @@ class SubMConv3dTestTorch(nn.Module): ...@@ -89,14 +105,14 @@ class SubMConv3dTestTorch(nn.Module):
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
bias=False, bias=False,
algo=algo)) algo=algo))
self.net = spconv.SparseSequential(*layers, ) self.net = spconv.SparseSequential(*layers, )
# self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda() # self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda()
self.grid = None self.grid = None
self.shape = shape self.shape = shape
def forward(self, features, coors, batch_size): def forward(self, features, coors, batch_size):
coors = coors.int()# .cpu() coors = coors.int() # .cpu()
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, x = spconv.SparseConvTensor(features, coors, self.shape, batch_size,
self.grid) self.grid)
return self.net(x) # .dense() return self.net(x) # .dense()
...@@ -599,13 +615,13 @@ class TestSpConv(TestCase): ...@@ -599,13 +615,13 @@ class TestSpConv(TestCase):
self.assertAllClose(din_np, din_sparse_np, atol=1e-4) self.assertAllClose(din_np, din_sparse_np, atol=1e-4)
def main(): def main(algo=spconv.ConvAlgo.Native):
# function for develop. # function for develop.
np.random.seed(484) np.random.seed(484)
# devices = ["cuda:0"] # devices = ["cuda:0"]
devices = ["cuda:0"] devices = ["cuda:0"]
shapes = [[50, 30, 30]] shapes = [[400, 400, 15]]
batchsizes = [2] batchsizes = [1]
in_channels = [32] in_channels = [32]
out_channels = [64] out_channels = [64]
...@@ -620,7 +636,7 @@ def main(): ...@@ -620,7 +636,7 @@ def main():
if all([s > 1, d > 1]): if all([s > 1, d > 1]):
continue continue
device = torch.device(dev) device = torch.device(dev)
num_points = [500] * bs num_points = [30000] * bs
sparse_dict = generate_sparse_data(shape, num_points, IC) sparse_dict = generate_sparse_data(shape, num_points, IC)
...@@ -636,8 +652,8 @@ def main(): ...@@ -636,8 +652,8 @@ def main():
features_t = torch.from_numpy(features).to(device).float() features_t = torch.from_numpy(features).to(device).float()
features_dense_t = torch.from_numpy(features_dense).to(device).float() features_dense_t = torch.from_numpy(features_dense).to(device).float()
net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d,
d).to(device).float() algo=algo).to(device).float()
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d).to(device).float() d).to(device).float()
filters_t = torch.from_numpy(filters).to(device).float() filters_t = torch.from_numpy(filters).to(device).float()
...@@ -662,7 +678,8 @@ def main(): ...@@ -662,7 +678,8 @@ def main():
print( print(
np.linalg.norm(out.detach().cpu().numpy() - np.linalg.norm(out.detach().cpu().numpy() -
out_ref.detach().cpu().numpy())) out_ref.detach().cpu().numpy()))
print(out_numpy.min(), out_numpy.max(), out_numpy.mean(), out_numpy.sum()) print(out_numpy.min(), out_numpy.max(), out_numpy.mean(),
out_numpy.sum())
def main_subm(algo): def main_subm(algo):
...@@ -671,7 +688,7 @@ def main_subm(algo): ...@@ -671,7 +688,7 @@ def main_subm(algo):
torch.manual_seed(50051) torch.manual_seed(50051)
# devices = ["cuda:0"] # devices = ["cuda:0"]
devices = ["cuda:0"] devices = ["cuda:0"]
shapes = [[50, 30, 30]] shapes = [[400, 400, 15]]
batchsizes = [2] batchsizes = [2]
in_channels = [32] in_channels = [32]
...@@ -686,7 +703,7 @@ def main_subm(algo): ...@@ -686,7 +703,7 @@ def main_subm(algo):
if all([s > 1, d > 1]): if all([s > 1, d > 1]):
continue continue
device = torch.device(dev) device = torch.device(dev)
num_points = [1000] * bs num_points = [240000] * bs
sparse_dict = generate_sparse_data(shape, num_points, IC) sparse_dict = generate_sparse_data(shape, num_points, IC)
...@@ -702,8 +719,8 @@ def main_subm(algo): ...@@ -702,8 +719,8 @@ def main_subm(algo):
features_t = torch.from_numpy(features).to(device).float() features_t = torch.from_numpy(features).to(device).float()
features_dense_t = torch.from_numpy(features_dense).to(device).float() features_dense_t = torch.from_numpy(features_dense).to(device).float()
net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d,
d, algo=algo).to(device).float() algo=algo).to(device).float()
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d).to(device).float() d).to(device).float()
filters_t = torch.from_numpy(filters).to(device).float() filters_t = torch.from_numpy(filters).to(device).float()
...@@ -712,7 +729,7 @@ def main_subm(algo): ...@@ -712,7 +729,7 @@ def main_subm(algo):
net.net[0].weight[:] = filters_t net.net[0].weight[:] = filters_t
out_ref = net_ref(features_dense_t) out_ref = net_ref(features_dense_t)
times = [] times = []
for i in range(100): for i in range(20):
t = time.time() t = time.time()
out = net(features_t, indices_t, bs) out = net(features_t, indices_t, bs)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -727,11 +744,13 @@ def main_subm(algo): ...@@ -727,11 +744,13 @@ def main_subm(algo):
print( print(
np.linalg.norm(out.detach().cpu().numpy() - np.linalg.norm(out.detach().cpu().numpy() -
out_ref.detach().cpu().numpy())) out_ref.detach().cpu().numpy()))
print(out_numpy.min(), out_numpy.max(), out_numpy.mean(), out_numpy.sum()) print(out_numpy.min(), out_numpy.max(), out_numpy.mean(),
out_numpy.sum())
return out_numpy return out_numpy
if __name__ == '__main__': if __name__ == '__main__':
# out_my = main_subm(algo=spconv.ConvAlgo.BatchGemm) # main_subm(algo=spconv.ConvAlgo.BatchGemmGather)
# out_ref = main_subm(algo=spconv.ConvAlgo.Native) # out_ref = main_subm(algo=spconv.ConvAlgo.Native)
# TestCase().assertAllClose(out_my, out_ref) # TestCase().assertAllClose(out_my, out_ref)
# unittest.main() # unittest.main()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment