Commit 105b3892 authored by Yan Yan
Browse files

batch indice conv alpha release

parent f08268fc
......@@ -290,8 +290,6 @@ __global__ void batchScatterAddGenericKernel(T *outFeatures, const T *buffer,
if (ix + ILPStrideX[ilp] < size && inds[ilp] != -1) {
gpuAtomicAdd(outFeatures + inds[ilp] * numPlanes + iy,
buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy]);
// outFeatures[inds[ilp] * numPlanes + iy] +=
// buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy];
}
}
}
......
......@@ -22,6 +22,12 @@
#include <utility/timer.h>
namespace spconv {
// Selects which convolution implementation indiceConv / indiceConvBackward
// run. The value arrives as a plain int64_t and is matched in a switch, so
// this stays an unscoped enum. The numeric values mirror the Python-side
// ConvAlgo enum and must not be renumbered.
enum ConvAlgo {
  // Per-kernel-offset path: one gather / mm / scatter-add round per offset.
  kNative = 0,
  // Batched path: dispatches to the *Batch variants when the kernel volume
  // is not 1; otherwise the native path is used.
  kBatchGemm = 1,
};
// torch.jit's documentation says it only supports int64, so we need to convert to int32.
template <unsigned NDim>
std::vector<torch::Tensor>
......@@ -344,12 +350,18 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs, torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse, int64_t _subM);
int64_t numActOut, int64_t _inverse, int64_t _subM,
int64_t algo);
std::vector<torch::Tensor>
indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse, int64_t _subM);
torch::Tensor indiceNum, int64_t _inverse, int64_t _subM, int64_t algo);
std::vector<torch::Tensor>
indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM);
} // namespace spconv
#endif
\ No newline at end of file
......@@ -88,12 +88,12 @@ class CMakeBuild(build_ext):
packages = find_packages(exclude=('tools', 'tools.*'))
setup(
name='spconv',
version='1.1',
version='1.2',
author='Yan Yan',
author_email='scrin@foxmail.com',
description='spatial sparse convolution for pytorch',
long_description='',
setup_requires = ['torch>=1.0.0'],
setup_requires = ['torch>=1.3.0'],
packages=packages,
package_dir = {'spconv': 'spconv'},
ext_modules=[CMakeExtension('spconv', library_dirs=[])],
......
......@@ -19,6 +19,7 @@ import numpy as np
import torch
from spconv import ops, utils
from spconv.ops import ConvAlgo
from spconv.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SubMConv2d, SubMConv3d)
......@@ -81,7 +82,7 @@ class SparseConvTensor(object):
def dense(self, channels_first=True):
output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(self.indices.long(), self.features, output_shape)
res = scatter_nd(self.indices.long().to(self.features.device), self.features, output_shape)
if not channels_first:
return res
ndim = len(self.spatial_shape)
......
......@@ -71,7 +71,8 @@ class SparseConvolution(SparseModule):
inverse=False,
indice_key=None,
fused_bn=False,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConvolution, self).__init__()
assert groups == 1
if not isinstance(kernel_size, (list, tuple)):
......@@ -104,6 +105,7 @@ class SparseConvolution(SparseModule):
self.indice_key = indice_key
self.fused_bn = fused_bn
self.use_hash = use_hash
self.algo = algo.value
self.weight = Parameter(
torch.Tensor(*kernel_size, in_channels, out_channels))
......@@ -194,17 +196,17 @@ class SparseConvolution(SparseModule):
out_features = Fsp.indice_subm_conv(features, self.weight,
indice_pairs.to(device),
indice_pair_num,
outids.shape[0])
outids.shape[0], self.algo)
else:
if self.inverse:
out_features = Fsp.indice_inverse_conv(
features, self.weight, indice_pairs.to(device),
indice_pair_num, outids.shape[0])
indice_pair_num, outids.shape[0], self.algo)
else:
out_features = Fsp.indice_conv(features, self.weight,
indice_pairs.to(device),
indice_pair_num,
outids.shape[0])
outids.shape[0], self.algo)
if self.bias is not None:
out_features += self.bias
......@@ -226,7 +228,8 @@ class SparseConv2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConv2d, self).__init__(2,
in_channels,
out_channels,
......@@ -237,7 +240,8 @@ class SparseConv2d(SparseConvolution):
groups,
bias,
indice_key=indice_key,
use_hash=use_hash)
use_hash=use_hash,
algo=algo)
class SparseConv3d(SparseConvolution):
......@@ -251,7 +255,8 @@ class SparseConv3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConv3d, self).__init__(3,
in_channels,
out_channels,
......@@ -262,7 +267,8 @@ class SparseConv3d(SparseConvolution):
groups,
bias,
indice_key=indice_key,
use_hash=use_hash)
use_hash=use_hash,
algo=algo)
class SparseConv4d(SparseConvolution):
......@@ -276,7 +282,8 @@ class SparseConv4d(SparseConvolution):
groups=1,
bias=True,
indice_key=None,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConv4d, self).__init__(4,
in_channels,
out_channels,
......@@ -287,7 +294,8 @@ class SparseConv4d(SparseConvolution):
groups,
bias,
indice_key=indice_key,
use_hash=use_hash)
use_hash=use_hash,
algo=algo)
class SparseConvTranspose2d(SparseConvolution):
......@@ -301,7 +309,8 @@ class SparseConvTranspose2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConvTranspose2d, self).__init__(2,
in_channels,
out_channels,
......@@ -313,7 +322,8 @@ class SparseConvTranspose2d(SparseConvolution):
bias,
transposed=True,
indice_key=indice_key,
use_hash=use_hash)
use_hash=use_hash,
algo=algo)
class SparseConvTranspose3d(SparseConvolution):
......@@ -327,7 +337,8 @@ class SparseConvTranspose3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConvTranspose3d, self).__init__(3,
in_channels,
out_channels,
......@@ -339,7 +350,8 @@ class SparseConvTranspose3d(SparseConvolution):
bias,
transposed=True,
indice_key=indice_key,
use_hash=use_hash)
use_hash=use_hash,
algo=algo)
class SparseInverseConv2d(SparseConvolution):
......@@ -348,14 +360,16 @@ class SparseInverseConv2d(SparseConvolution):
out_channels,
kernel_size,
indice_key,
bias=True):
bias=True,
algo=ops.ConvAlgo.Native):
super(SparseInverseConv2d, self).__init__(2,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key)
indice_key=indice_key,
algo=algo)
class SparseInverseConv3d(SparseConvolution):
......@@ -364,14 +378,16 @@ class SparseInverseConv3d(SparseConvolution):
out_channels,
kernel_size,
indice_key,
bias=True):
bias=True,
algo=ops.ConvAlgo.Native):
super(SparseInverseConv3d, self).__init__(3,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key)
indice_key=indice_key,
algo=algo)
class SubMConv2d(SparseConvolution):
......@@ -385,7 +401,8 @@ class SubMConv2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SubMConv2d, self).__init__(2,
in_channels,
out_channels,
......@@ -397,7 +414,8 @@ class SubMConv2d(SparseConvolution):
bias,
True,
indice_key=indice_key,
use_hash=use_hash)
use_hash=use_hash,
algo=algo)
class SubMConv3d(SparseConvolution):
......@@ -411,7 +429,8 @@ class SubMConv3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SubMConv3d, self).__init__(3,
in_channels,
out_channels,
......@@ -423,7 +442,8 @@ class SubMConv3d(SparseConvolution):
bias,
True,
indice_key=indice_key,
use_hash=use_hash)
use_hash=use_hash,
algo=algo)
class SubMConv4d(SparseConvolution):
......@@ -437,7 +457,8 @@ class SubMConv4d(SparseConvolution):
groups=1,
bias=True,
indice_key=None,
use_hash=False):
use_hash=False,
algo=ops.ConvAlgo.Native):
super(SubMConv4d, self).__init__(4,
in_channels,
out_channels,
......@@ -449,4 +470,5 @@ class SubMConv4d(SparseConvolution):
bias,
True,
indice_key=indice_key,
use_hash=use_hash)
use_hash=use_hash,
algo=algo)
......@@ -22,55 +22,59 @@ import spconv.ops as ops
class SparseConvFunction(Function):
@staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out):
num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out, False)
indice_pair_num, num_activate_out, False, algo=algo)
@staticmethod
def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num,
False)
False, algo=ctx.algo)
return input_bp, filters_bp, None, None, None
return input_bp, filters_bp, None, None, None, None
class SparseInverseConvFunction(Function):
@staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out):
num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out, True, False)
indice_pair_num, num_activate_out, True, False, algo=algo)
@staticmethod
def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num,
True, False)
True, False, algo=ctx.algo)
return input_bp, filters_bp, None, None, None
return input_bp, filters_bp, None, None, None, None
class SubMConvFunction(Function):
@staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out):
num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out, False, True)
indice_pair_num, num_activate_out, False, True, algo=algo)
@staticmethod
def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num,
False, True)
False, True, algo=ctx.algo)
return input_bp, filters_bp, None, None, None
return input_bp, filters_bp, None, None, None, None
class SparseMaxPoolFunction(Function):
......
......@@ -16,6 +16,11 @@ import torch
import spconv
from enum import Enum
class ConvAlgo(Enum):
    """Convolution algorithm selector passed down to the spconv ops.

    The integer values are forwarded (via ``.value``) to the C++ ops and must
    stay in sync with the C++ ``ConvAlgo`` enum (kNative = 0, kBatchGemm = 1).
    """
    # Per-kernel-offset implementation (default everywhere in this module).
    Native = 0
    # Batched-GEMM implementation.
    BatchGemm = 1
def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
ndim = len(input_size)
......@@ -106,10 +111,11 @@ def indice_conv(features,
indice_pair_num,
num_activate_out,
inverse=False,
subm=False):
return torch.ops.spconv.indice_conv_batch(features, filters, indice_pairs,
subm=False,
algo=ConvAlgo.Native.value):
return torch.ops.spconv.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out,
int(inverse), int(subm))
int(inverse), int(subm), algo)
def fused_indice_conv(features, filters, bias, indice_pairs, indice_pair_num,
......@@ -126,10 +132,11 @@ def indice_conv_backward(features,
indice_pairs,
indice_pair_num,
inverse=False,
subm=False):
subm=False,
algo=ConvAlgo.Native.value):
return torch.ops.spconv.indice_conv_backward(features, filters, out_bp,
indice_pairs, indice_pair_num,
int(inverse), int(subm))
int(inverse), int(subm), algo)
def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out):
......
......@@ -48,6 +48,7 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
using Index = decltype(IndexValue);
bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>(
[=, &buffer, &features, &indices, &notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4;
......@@ -80,7 +81,6 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
}
}
});
if (notFound) {
constexpr int NumTLP = 64;
constexpr int NumILP = NumTLP / 4;
......@@ -115,6 +115,7 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
bool notFound = true;
constexpr int vecloadFactor =
sizeof(vecload_type_t) / sizeof(T); // important for half.
tv::mp_for_each<kernel_block_t>([=, &outFeatures, &buffer, &indices,
&notFound](auto NumTLP) {
// constexpr int NumILP = NumTLP / (64 / (NumTLP /
......@@ -183,7 +184,6 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
using Index = decltype(IndexValue);
bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>(
[=, &buffer, &features, &indices, &notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4;
......@@ -204,13 +204,12 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
TV_CHECK_CUDA_ERR();
}
if (size - nHotBlock > 0) {
auto indices_offset = (nHotBlock / feature_stride) * inds_stride + nHotBlock % feature_stride;
batchGatherVecKernel<T, Index, int(NumTLP), NumILP, vecload_type_t>
<<<dim3(1, numPlanes / NumTLP),
dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0,
stream>>>(buffer.data_ptr<T>() + nHotBlock * numPlanes,
features.data_ptr<T>(),
indices.data_ptr<Index>() + indices_offset,
indices.data_ptr<Index>(),
size - nHotBlock, nHotBlock, numPlanes / vecloadFactor,
inds_stride, feature_stride);
TV_CHECK_CUDA_ERR();
......@@ -280,7 +279,6 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
TV_CHECK_CUDA_ERR();
}
if (size - nHotBlock > 0) {
// int indices_offset = (nHotBlock / feature_stride) * inds_stride + nHotBlock % feature_stride;
batchScatterAddGenericKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, stream>>>(outFeatures.data_ptr<T>(),
......
......@@ -83,7 +83,7 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
return {indices, indicePairs, indiceNum};
} else {
......@@ -127,7 +127,7 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
return {outInds.slice(0, 0, numActOut), indicePairs, indiceNum};
}
......@@ -135,12 +135,27 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs, torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse, int64_t _subM) {
int64_t numActOut, int64_t _inverse, int64_t _subM,
int64_t algo) {
auto kernelVolume = indiceNum.size(0);
switch (algo) {
case kBatchGemm: {
if (kernelVolume != 1) {
return indiceConvBatch(features, filters, indicePairs, indiceNum,
numActOut, _inverse, _subM);
} else {
break;
}
}
case kNative:
break;
default:
TV_THROW_RT_ERR("unknown algo");
}
bool subM = _subM != 0;
bool inverse = _inverse != 0;
auto device = features.device().type();
auto ndim = filters.dim() - 2;
auto kernelVolume = indiceNum.size(0);
auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
......@@ -157,7 +172,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::Tensor inputBuffer =
torch::zeros({indicePairMaxSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::zeros({indicePairMaxSize, numOutPlanes}, options);
torch::empty({indicePairMaxSize, numOutPlanes}, options);
filters = filters.view({-1, numInPlanes, numOutPlanes});
if (subM) { // the center index of subm conv don't need gather and scatter
// add.
......@@ -191,7 +206,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]);
if (device == torch::kCPU) {
......@@ -205,7 +220,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
}
return output;
......@@ -220,33 +235,38 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
auto device = features.device().type();
auto ndim = filters.dim() - 2;
auto kernelVolume = indiceNum.size(0);
TV_ASSERT_INVALID_ARG(kernelVolume > 1, "error");
auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairMaxSizeIter =
std::max_element(indicePairNumCpu.data_ptr<int>(),
auto indicePairNumVec =
std::vector<int>(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset =
indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
auto indicePairMaxSizeIter =
std::max_element(indicePairNumVec.begin(), indicePairNumVec.end());
int indicePairMaxOffset = indicePairMaxSizeIter - indicePairNumVec.begin();
int indicePairMaxSize = *indicePairMaxSizeIter;
std::nth_element(indicePairNumVec.begin(), indicePairNumVec.begin() + 1,
indicePairNumVec.end(), std::greater<int>());
int indicePairTop2Size = indicePairNumVec[1];
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
auto indice_dtype = indicePairs.scalar_type();
torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options);
// We can't use batch conv for subm directly, because the number of indices
// at the center of the filter is much larger than at any other filter
// location. So we first run batch conv with buffers sized by the
// second-largest indice count, then run native conv on the remaining
// center indices.
int bufferSize = subM ? indicePairTop2Size : indicePairMaxSize;
torch::Tensor inputBuffer =
torch::zeros({kernelVolume, indicePairMaxSize, numInPlanes}, options);
torch::zeros({kernelVolume, bufferSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::zeros({kernelVolume, indicePairMaxSize, numOutPlanes}, options);
torch::empty({kernelVolume, bufferSize, numOutPlanes}, options);
filters = filters.view({kernelVolume, numInPlanes, numOutPlanes});
if (subM) { // the center index of subm conv don't need gather and scatter
// add.
torch::mm_out(output, features, filters[indicePairMaxOffset]);
}
double totalGatherTime = 0;
double totalGEMMTime = 0;
double totalSAddTime = 0;
auto size = kernelVolume * indicePairMaxSize;
int64_t size = kernelVolume * bufferSize;
if (device == torch::kCPU) {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
......@@ -254,11 +274,11 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
torch::bmm_out(outputBuffer, inputBuffer, filters);
if (device == torch::kCPU) {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
......@@ -267,7 +287,54 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
if (subM) {
auto remain_size = indicePairMaxSize - indicePairTop2Size;
if (remain_size <= 0) {
return output;
}
inputBuffer = torch::empty({remain_size, numInPlanes}, options);
outputBuffer = torch::empty({remain_size, numOutPlanes}, options);
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) {
using Index = decltype(I);
auto indicePairsRemain = torch::from_blob(
indicePairs[inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
sparse_gather_cuda(inputBuffer, features, indicePairsRemain,
remain_size);
});
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
torch::mm_out(outputBuffer, inputBuffer, filters[indicePairMaxOffset]);
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) {
using Index = decltype(I);
auto indicePairsRemain = torch::from_blob(
indicePairs[!inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
sparse_scatter_add_cuda(outputBuffer, output, indicePairsRemain,
remain_size);
});
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
}
return output;
}
......@@ -275,13 +342,29 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
std::vector<torch::Tensor>
indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse, int64_t _subM) {
torch::Tensor indiceNum, int64_t _inverse, int64_t _subM,
int64_t algo) {
auto kernelVolume = indiceNum.size(0);
switch (algo) {
case kBatchGemm: {
if (kernelVolume != 1) {
return indiceConvBackwardBatch(features, filters, outGrad, indicePairs,
indiceNum, _inverse, _subM);
} else {
break;
}
}
case kNative:
break;
default:
TV_THROW_RT_ERR("unknown algo");
}
bool subM = _subM != 0;
bool inverse = _inverse != 0;
auto device = features.device().type();
auto ndim = filters.dim() - 2;
auto kernelVolume = indiceNum.size(0);
auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
......@@ -324,7 +407,7 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
auto filterGradSub = filtersGrad[i];
......@@ -346,9 +429,139 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
}
#endif
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
TV_THROW_INVALID_ARG("unknown device type");
}
}
return {inputGrad, filtersGrad.view(filterShape)};
}
// Backward pass of the batched-GEMM indice convolution.
//
// Gathers the input features and output gradients for all kernel offsets
// into [kernelVolume, bufferSize, C] buffers, computes the filter gradient
// and the input gradient with two bmm calls, and scatter-adds the input
// gradient back. Only the CUDA path is implemented; a CPU device throws.
//
// For subM convolutions the buffers are sized by the second-largest pair
// count (the largest offset has far more pairs than the others); the pairs
// of the largest offset beyond that size are then processed separately with
// plain mm calls and accumulated into the results.
//
// Args:
//   features:    input features, [numActIn, numInPlanes].
//   filters:     filter weights, [..., numInPlanes, numOutPlanes].
//   outGrad:     gradient w.r.t. the output features.
//   indicePairs: indice pairs, indexed as [inverse / !inverse][offset][pair].
//   indiceNum:   number of valid pairs per kernel offset (int32 data).
//   _inverse:    non-zero for inverse convolution.
//   _subM:       non-zero for submanifold convolution.
//
// Returns {inputGrad, filtersGrad}; filtersGrad has the shape of `filters`.
std::vector<torch::Tensor>
indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
                        torch::Tensor outGrad, torch::Tensor indicePairs,
                        torch::Tensor indiceNum, int64_t _inverse,
                        int64_t _subM) {
  bool subM = _subM != 0;
  bool inverse = _inverse != 0;

  auto device = features.device().type();
  auto ndim = filters.dim() - 2;
  auto kernelVolume = indiceNum.size(0);
  // The top-2 selection below reads element [1], so two offsets are required.
  TV_ASSERT_INVALID_ARG(kernelVolume > 1, "error");
  auto numInPlanes = features.size(1);
  auto numOutPlanes = filters.size(ndim + 1);
  auto indicePairNumCpu = indiceNum.to({torch::kCPU});
  auto indicePairNumVec =
      std::vector<int>(indicePairNumCpu.data_ptr<int>(),
                       indicePairNumCpu.data_ptr<int>() + kernelVolume);
  // Largest pair count and its offset. Both are read before nth_element
  // reorders the vector.
  auto indicePairMaxSizeIter =
      std::max_element(indicePairNumVec.begin(), indicePairNumVec.end());
  int indicePairMaxOffset = indicePairMaxSizeIter - indicePairNumVec.begin();
  int indicePairMaxSize = *indicePairMaxSizeIter;
  // Second-largest pair count.
  std::nth_element(indicePairNumVec.begin(), indicePairNumVec.begin() + 1,
                   indicePairNumVec.end(), std::greater<int>());
  int indicePairTop2Size = indicePairNumVec[1];
  auto options =
      torch::TensorOptions().dtype(features.dtype()).device(features.device());
  auto indice_dtype = indicePairs.scalar_type();
  auto filterShape = filters.sizes();
  torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
  torch::Tensor filtersGrad = torch::zeros(filterShape, options);
  int bufferSize = subM ? indicePairTop2Size : indicePairMaxSize;
  torch::Tensor inputBuffer =
      torch::zeros({kernelVolume, bufferSize, numInPlanes}, options);
  torch::Tensor outputBuffer =
      torch::zeros({kernelVolume, bufferSize, numOutPlanes}, options);

  filters = filters.view({-1, numInPlanes, numOutPlanes});
  filtersGrad = filtersGrad.view({-1, numInPlanes, numOutPlanes});
  int64_t size = kernelVolume * bufferSize;
  if (device == torch::kCPU) {
    TV_THROW_INVALID_ARG("unknown device type");
  }
#ifdef TV_CUDA
  else if (device == torch::kCUDA) {
    batch_sparse_gather_cuda(inputBuffer, features, indicePairs[inverse], size);
    batch_sparse_gather_cuda(outputBuffer, outGrad, indicePairs[!inverse],
                             size);
  }
#endif
  else {
    TV_THROW_INVALID_ARG("unknown device type");
  }
  // filters: KV, I, O, inputBuffer: [KV, buffer, I]
  // outputBuffer: [KV, buffer, O]
  torch::bmm_out(filtersGrad, inputBuffer.permute({0, 2, 1}), outputBuffer);
  // inputBuffer is reused to hold the per-pair input gradient.
  torch::bmm_out(inputBuffer, outputBuffer, filters.permute({0, 2, 1}));
  if (device == torch::kCPU) {
    TV_THROW_INVALID_ARG("unknown device type");
  }
#ifdef TV_CUDA
  else if (device == torch::kCUDA) {
    batch_sparse_scatter_add_cuda(inputBuffer, inputGrad, indicePairs[inverse],
                                  size);
  }
#endif
  else {
    TV_THROW_INVALID_ARG("unknown device type");
  }

  if (subM) {
    // Pairs of the largest offset that did not fit into the batched buffers.
    auto remain_size = indicePairMaxSize - indicePairTop2Size;
    if (remain_size <= 0) {
      return {inputGrad, filtersGrad.view(filterShape)};
    }
    inputBuffer = torch::zeros({remain_size, numInPlanes}, options);
    outputBuffer = torch::zeros({remain_size, numOutPlanes}, options);
    if (device == torch::kCPU) {
      TV_THROW_INVALID_ARG("unknown device type");
    }
#ifdef TV_CUDA
    else if (device == torch::kCUDA) {
      tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) {
        using Index = decltype(I);
        auto indicePairsRemain = torch::from_blob(
            indicePairs[inverse][indicePairMaxOffset].data_ptr<Index>() +
                indicePairTop2Size,
            {remain_size}, indicePairs.options());
        auto indicePairsRemain2 = torch::from_blob(
            indicePairs[!inverse][indicePairMaxOffset].data_ptr<Index>() +
                indicePairTop2Size,
            {remain_size}, indicePairs.options());
        // Buffers and indices are 2-D / 1-D here, so use the non-batched
        // kernels, as the forward pass does for the same remainder.
        sparse_gather_cuda(inputBuffer, features, indicePairsRemain,
                           remain_size);
        sparse_gather_cuda(outputBuffer, outGrad, indicePairsRemain2,
                           remain_size);
      });
    }
#endif
    else {
      TV_THROW_INVALID_ARG("unknown device type");
    }
    // Accumulate into the slice of the largest offset only. Writing the 2-D
    // product into the whole filtersGrad tensor would discard the gradients
    // computed by the batched GEMM above and break the final view().
    auto filterGradCenter = filtersGrad[indicePairMaxOffset];
    filterGradCenter.addmm_(inputBuffer.t(), outputBuffer);
    torch::mm_out(inputBuffer, outputBuffer, filters[indicePairMaxOffset].t());
    if (device == torch::kCPU) {
      TV_THROW_INVALID_ARG("unknown device type");
    }
#ifdef TV_CUDA
    else if (device == torch::kCUDA) {
      tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) {
        using Index = decltype(I);
        // Input gradients are scattered with the same index set used to
        // gather the input features (indicePairs[inverse]), matching the
        // batched section above.
        auto indicePairsRemain = torch::from_blob(
            indicePairs[inverse][indicePairMaxOffset].data_ptr<Index>() +
                indicePairTop2Size,
            {remain_size}, indicePairs.options());
        sparse_scatter_add_cuda(inputBuffer, inputGrad, indicePairsRemain,
                                remain_size);
      });
    }
#endif
    else {
      TV_THROW_INVALID_ARG("unknown device type");
    }
  }
  return {inputGrad, filtersGrad.view(filterShape)};
}
......
......@@ -30,6 +30,8 @@ class SparseConv3dTestTorch(nn.Module):
def __init__(self, num_layers, ndim, shape, in_channels, out_channels,
kernel_size, stride, padding, dilation):
super().__init__()
algo = spconv.ConvAlgo.BatchGemm
layers = [
spconv.SparseConv3d(in_channels,
out_channels,
......@@ -38,7 +40,8 @@ class SparseConv3dTestTorch(nn.Module):
padding=padding,
dilation=dilation,
bias=False,
use_hash=False)
use_hash=False,
algo=algo)
]
for i in range(1, num_layers):
layers.append(
......@@ -49,7 +52,8 @@ class SparseConv3dTestTorch(nn.Module):
padding=padding,
dilation=dilation,
bias=False,
use_hash=False))
use_hash=False,
algo=algo))
self.net = spconv.SparseSequential(*layers, )
# self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda()
self.grid = None
......@@ -64,7 +68,7 @@ class SparseConv3dTestTorch(nn.Module):
class SubMConv3dTestTorch(nn.Module):
def __init__(self, num_layers, ndim, shape, in_channels, out_channels,
kernel_size, stride, padding, dilation):
kernel_size, stride, padding, dilation, algo=spconv.ConvAlgo.Native):
super().__init__()
layers = [
spconv.SubMConv3d(in_channels,
......@@ -73,7 +77,8 @@ class SubMConv3dTestTorch(nn.Module):
stride,
padding=padding,
dilation=dilation,
bias=False)
bias=False,
algo=algo)
]
for i in range(1, num_layers):
layers.append(
......@@ -83,14 +88,15 @@ class SubMConv3dTestTorch(nn.Module):
stride,
padding=padding,
dilation=dilation,
bias=False))
bias=False,
algo=algo))
self.net = spconv.SparseSequential(*layers, )
# self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda()
self.grid = None
self.shape = shape
def forward(self, features, coors, batch_size):
coors = coors.int()
coors = coors.int()# .cpu()
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size,
self.grid)
return self.net(x) # .dense()
......@@ -327,7 +333,7 @@ def scatter_nd(indices, updates, shape):
class TestSpConv(TestCase):
def testSpConv3d(self):
np.random.seed(484)
devices = ["cuda:0", "cpu:0"]
devices = ["cuda:0"]
shapes = [[19, 18, 17]]
batchsizes = [1, 2]
......@@ -651,14 +657,18 @@ def main():
out = net(features_t, indices_t, bs)
# print(out.indices)
out = out.dense()
out_numpy = out.detach().cpu().numpy()
print(
np.linalg.norm(out.detach().cpu().numpy() -
out_ref.detach().cpu().numpy()))
print(out_numpy.min(), out_numpy.max(), out_numpy.mean(), out_numpy.sum())
def main_subm():
def main_subm(algo):
# function for develop.
np.random.seed(484)
torch.manual_seed(50051)
# devices = ["cuda:0"]
devices = ["cuda:0"]
shapes = [[50, 30, 30]]
......@@ -670,14 +680,13 @@ def main_subm():
strides = [1]
paddings = [1]
dilations = [1]
for dev, shape, bs, IC, OC, k, s, p, d in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations):
if all([s > 1, d > 1]):
continue
device = torch.device(dev)
num_points = [500] * bs
num_points = [1000] * bs
sparse_dict = generate_sparse_data(shape, num_points, IC)
......@@ -694,7 +703,7 @@ def main_subm():
features_dense_t = torch.from_numpy(features_dense).to(device).float()
net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d).to(device).float()
d, algo=algo).to(device).float()
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d).to(device).float()
filters_t = torch.from_numpy(filters).to(device).float()
......@@ -714,12 +723,16 @@ def main_subm():
out = net(features_t, indices_t, bs)
# print(out.indices)
out = out.dense()
out_numpy = out.detach().cpu().numpy()
print(
np.linalg.norm(out.detach().cpu().numpy() -
out_ref.detach().cpu().numpy()))
print(out_numpy.min(), out_numpy.max(), out_numpy.mean(), out_numpy.sum())
return out_numpy
if __name__ == '__main__':
main()
# out_my = main_subm(algo=spconv.ConvAlgo.BatchGemm)
# out_ref = main_subm(algo=spconv.ConvAlgo.Native)
# TestCase().assertAllClose(out_my, out_ref)
# unittest.main()
# TestSpConv().testSpConv3d()
TestSpConv().testSpConv3d()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment