Commit 105b3892 authored by Yan Yan's avatar Yan Yan
Browse files

batch indice conv alpha release

parent f08268fc
...@@ -290,8 +290,6 @@ __global__ void batchScatterAddGenericKernel(T *outFeatures, const T *buffer, ...@@ -290,8 +290,6 @@ __global__ void batchScatterAddGenericKernel(T *outFeatures, const T *buffer,
if (ix + ILPStrideX[ilp] < size && inds[ilp] != -1) { if (ix + ILPStrideX[ilp] < size && inds[ilp] != -1) {
gpuAtomicAdd(outFeatures + inds[ilp] * numPlanes + iy, gpuAtomicAdd(outFeatures + inds[ilp] * numPlanes + iy,
buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy]); buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy]);
// outFeatures[inds[ilp] * numPlanes + iy] +=
// buffer[(ix + ILPStrideX[ilp]) * numPlanes + iy];
} }
} }
} }
......
...@@ -22,6 +22,12 @@ ...@@ -22,6 +22,12 @@
#include <utility/timer.h> #include <utility/timer.h>
namespace spconv { namespace spconv {
enum ConvAlgo {
kNative = 0,
kBatchGemm = 1
};
// torch.jit's doc says only support int64, so we need to convert to int32. // torch.jit's doc says only support int64, so we need to convert to int32.
template <unsigned NDim> template <unsigned NDim>
std::vector<torch::Tensor> std::vector<torch::Tensor>
...@@ -344,12 +350,18 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, ...@@ -344,12 +350,18 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs, torch::Tensor indiceNum, torch::Tensor indicePairs, torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse, int64_t _subM); int64_t numActOut, int64_t _inverse, int64_t _subM,
int64_t algo);
std::vector<torch::Tensor> std::vector<torch::Tensor>
indiceConvBackward(torch::Tensor features, torch::Tensor filters, indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs, torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse, int64_t _subM); torch::Tensor indiceNum, int64_t _inverse, int64_t _subM, int64_t algo);
std::vector<torch::Tensor>
indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM);
} // namespace spconv } // namespace spconv
#endif #endif
\ No newline at end of file
...@@ -88,12 +88,12 @@ class CMakeBuild(build_ext): ...@@ -88,12 +88,12 @@ class CMakeBuild(build_ext):
packages = find_packages(exclude=('tools', 'tools.*')) packages = find_packages(exclude=('tools', 'tools.*'))
setup( setup(
name='spconv', name='spconv',
version='1.1', version='1.2',
author='Yan Yan', author='Yan Yan',
author_email='scrin@foxmail.com', author_email='scrin@foxmail.com',
description='spatial sparse convolution for pytorch', description='spatial sparse convolution for pytorch',
long_description='', long_description='',
setup_requires = ['torch>=1.0.0'], setup_requires = ['torch>=1.3.0'],
packages=packages, packages=packages,
package_dir = {'spconv': 'spconv'}, package_dir = {'spconv': 'spconv'},
ext_modules=[CMakeExtension('spconv', library_dirs=[])], ext_modules=[CMakeExtension('spconv', library_dirs=[])],
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import torch import torch
from spconv import ops, utils from spconv import ops, utils
from spconv.ops import ConvAlgo
from spconv.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d, from spconv.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
SparseConvTranspose3d, SparseInverseConv2d, SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SubMConv2d, SubMConv3d) SparseInverseConv3d, SubMConv2d, SubMConv3d)
...@@ -81,7 +82,7 @@ class SparseConvTensor(object): ...@@ -81,7 +82,7 @@ class SparseConvTensor(object):
def dense(self, channels_first=True): def dense(self, channels_first=True):
output_shape = [self.batch_size] + list( output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]] self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(self.indices.long(), self.features, output_shape) res = scatter_nd(self.indices.long().to(self.features.device), self.features, output_shape)
if not channels_first: if not channels_first:
return res return res
ndim = len(self.spatial_shape) ndim = len(self.spatial_shape)
......
...@@ -71,7 +71,8 @@ class SparseConvolution(SparseModule): ...@@ -71,7 +71,8 @@ class SparseConvolution(SparseModule):
inverse=False, inverse=False,
indice_key=None, indice_key=None,
fused_bn=False, fused_bn=False,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConvolution, self).__init__() super(SparseConvolution, self).__init__()
assert groups == 1 assert groups == 1
if not isinstance(kernel_size, (list, tuple)): if not isinstance(kernel_size, (list, tuple)):
...@@ -104,6 +105,7 @@ class SparseConvolution(SparseModule): ...@@ -104,6 +105,7 @@ class SparseConvolution(SparseModule):
self.indice_key = indice_key self.indice_key = indice_key
self.fused_bn = fused_bn self.fused_bn = fused_bn
self.use_hash = use_hash self.use_hash = use_hash
self.algo = algo.value
self.weight = Parameter( self.weight = Parameter(
torch.Tensor(*kernel_size, in_channels, out_channels)) torch.Tensor(*kernel_size, in_channels, out_channels))
...@@ -194,17 +196,17 @@ class SparseConvolution(SparseModule): ...@@ -194,17 +196,17 @@ class SparseConvolution(SparseModule):
out_features = Fsp.indice_subm_conv(features, self.weight, out_features = Fsp.indice_subm_conv(features, self.weight,
indice_pairs.to(device), indice_pairs.to(device),
indice_pair_num, indice_pair_num,
outids.shape[0]) outids.shape[0], self.algo)
else: else:
if self.inverse: if self.inverse:
out_features = Fsp.indice_inverse_conv( out_features = Fsp.indice_inverse_conv(
features, self.weight, indice_pairs.to(device), features, self.weight, indice_pairs.to(device),
indice_pair_num, outids.shape[0]) indice_pair_num, outids.shape[0], self.algo)
else: else:
out_features = Fsp.indice_conv(features, self.weight, out_features = Fsp.indice_conv(features, self.weight,
indice_pairs.to(device), indice_pairs.to(device),
indice_pair_num, indice_pair_num,
outids.shape[0]) outids.shape[0], self.algo)
if self.bias is not None: if self.bias is not None:
out_features += self.bias out_features += self.bias
...@@ -226,7 +228,8 @@ class SparseConv2d(SparseConvolution): ...@@ -226,7 +228,8 @@ class SparseConv2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConv2d, self).__init__(2, super(SparseConv2d, self).__init__(2,
in_channels, in_channels,
out_channels, out_channels,
...@@ -237,7 +240,8 @@ class SparseConv2d(SparseConvolution): ...@@ -237,7 +240,8 @@ class SparseConv2d(SparseConvolution):
groups, groups,
bias, bias,
indice_key=indice_key, indice_key=indice_key,
use_hash=use_hash) use_hash=use_hash,
algo=algo)
class SparseConv3d(SparseConvolution): class SparseConv3d(SparseConvolution):
...@@ -251,7 +255,8 @@ class SparseConv3d(SparseConvolution): ...@@ -251,7 +255,8 @@ class SparseConv3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConv3d, self).__init__(3, super(SparseConv3d, self).__init__(3,
in_channels, in_channels,
out_channels, out_channels,
...@@ -262,7 +267,8 @@ class SparseConv3d(SparseConvolution): ...@@ -262,7 +267,8 @@ class SparseConv3d(SparseConvolution):
groups, groups,
bias, bias,
indice_key=indice_key, indice_key=indice_key,
use_hash=use_hash) use_hash=use_hash,
algo=algo)
class SparseConv4d(SparseConvolution): class SparseConv4d(SparseConvolution):
...@@ -276,7 +282,8 @@ class SparseConv4d(SparseConvolution): ...@@ -276,7 +282,8 @@ class SparseConv4d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConv4d, self).__init__(4, super(SparseConv4d, self).__init__(4,
in_channels, in_channels,
out_channels, out_channels,
...@@ -287,7 +294,8 @@ class SparseConv4d(SparseConvolution): ...@@ -287,7 +294,8 @@ class SparseConv4d(SparseConvolution):
groups, groups,
bias, bias,
indice_key=indice_key, indice_key=indice_key,
use_hash=use_hash) use_hash=use_hash,
algo=algo)
class SparseConvTranspose2d(SparseConvolution): class SparseConvTranspose2d(SparseConvolution):
...@@ -301,7 +309,8 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -301,7 +309,8 @@ class SparseConvTranspose2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConvTranspose2d, self).__init__(2, super(SparseConvTranspose2d, self).__init__(2,
in_channels, in_channels,
out_channels, out_channels,
...@@ -313,7 +322,8 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -313,7 +322,8 @@ class SparseConvTranspose2d(SparseConvolution):
bias, bias,
transposed=True, transposed=True,
indice_key=indice_key, indice_key=indice_key,
use_hash=use_hash) use_hash=use_hash,
algo=algo)
class SparseConvTranspose3d(SparseConvolution): class SparseConvTranspose3d(SparseConvolution):
...@@ -327,7 +337,8 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -327,7 +337,8 @@ class SparseConvTranspose3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SparseConvTranspose3d, self).__init__(3, super(SparseConvTranspose3d, self).__init__(3,
in_channels, in_channels,
out_channels, out_channels,
...@@ -339,7 +350,8 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -339,7 +350,8 @@ class SparseConvTranspose3d(SparseConvolution):
bias, bias,
transposed=True, transposed=True,
indice_key=indice_key, indice_key=indice_key,
use_hash=use_hash) use_hash=use_hash,
algo=algo)
class SparseInverseConv2d(SparseConvolution): class SparseInverseConv2d(SparseConvolution):
...@@ -348,14 +360,16 @@ class SparseInverseConv2d(SparseConvolution): ...@@ -348,14 +360,16 @@ class SparseInverseConv2d(SparseConvolution):
out_channels, out_channels,
kernel_size, kernel_size,
indice_key, indice_key,
bias=True): bias=True,
algo=ops.ConvAlgo.Native):
super(SparseInverseConv2d, self).__init__(2, super(SparseInverseConv2d, self).__init__(2,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
bias=bias, bias=bias,
inverse=True, inverse=True,
indice_key=indice_key) indice_key=indice_key,
algo=algo)
class SparseInverseConv3d(SparseConvolution): class SparseInverseConv3d(SparseConvolution):
...@@ -364,14 +378,16 @@ class SparseInverseConv3d(SparseConvolution): ...@@ -364,14 +378,16 @@ class SparseInverseConv3d(SparseConvolution):
out_channels, out_channels,
kernel_size, kernel_size,
indice_key, indice_key,
bias=True): bias=True,
algo=ops.ConvAlgo.Native):
super(SparseInverseConv3d, self).__init__(3, super(SparseInverseConv3d, self).__init__(3,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
bias=bias, bias=bias,
inverse=True, inverse=True,
indice_key=indice_key) indice_key=indice_key,
algo=algo)
class SubMConv2d(SparseConvolution): class SubMConv2d(SparseConvolution):
...@@ -385,7 +401,8 @@ class SubMConv2d(SparseConvolution): ...@@ -385,7 +401,8 @@ class SubMConv2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SubMConv2d, self).__init__(2, super(SubMConv2d, self).__init__(2,
in_channels, in_channels,
out_channels, out_channels,
...@@ -397,7 +414,8 @@ class SubMConv2d(SparseConvolution): ...@@ -397,7 +414,8 @@ class SubMConv2d(SparseConvolution):
bias, bias,
True, True,
indice_key=indice_key, indice_key=indice_key,
use_hash=use_hash) use_hash=use_hash,
algo=algo)
class SubMConv3d(SparseConvolution): class SubMConv3d(SparseConvolution):
...@@ -411,7 +429,8 @@ class SubMConv3d(SparseConvolution): ...@@ -411,7 +429,8 @@ class SubMConv3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SubMConv3d, self).__init__(3, super(SubMConv3d, self).__init__(3,
in_channels, in_channels,
out_channels, out_channels,
...@@ -423,7 +442,8 @@ class SubMConv3d(SparseConvolution): ...@@ -423,7 +442,8 @@ class SubMConv3d(SparseConvolution):
bias, bias,
True, True,
indice_key=indice_key, indice_key=indice_key,
use_hash=use_hash) use_hash=use_hash,
algo=algo)
class SubMConv4d(SparseConvolution): class SubMConv4d(SparseConvolution):
...@@ -437,7 +457,8 @@ class SubMConv4d(SparseConvolution): ...@@ -437,7 +457,8 @@ class SubMConv4d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False,
algo=ops.ConvAlgo.Native):
super(SubMConv4d, self).__init__(4, super(SubMConv4d, self).__init__(4,
in_channels, in_channels,
out_channels, out_channels,
...@@ -449,4 +470,5 @@ class SubMConv4d(SparseConvolution): ...@@ -449,4 +470,5 @@ class SubMConv4d(SparseConvolution):
bias, bias,
True, True,
indice_key=indice_key, indice_key=indice_key,
use_hash=use_hash) use_hash=use_hash,
algo=algo)
...@@ -22,55 +22,59 @@ import spconv.ops as ops ...@@ -22,55 +22,59 @@ import spconv.ops as ops
class SparseConvFunction(Function): class SparseConvFunction(Function):
@staticmethod @staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs, return ops.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out, False) indice_pair_num, num_activate_out, False, algo=algo)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward( input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num, features, filters, grad_output, indice_pairs, indice_pair_num,
False) False, algo=ctx.algo)
return input_bp, filters_bp, None, None, None return input_bp, filters_bp, None, None, None, None
class SparseInverseConvFunction(Function): class SparseInverseConvFunction(Function):
@staticmethod @staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs, return ops.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out, True, False) indice_pair_num, num_activate_out, True, False, algo=algo)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward( input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num, features, filters, grad_output, indice_pairs, indice_pair_num,
True, False) True, False, algo=ctx.algo)
return input_bp, filters_bp, None, None, None return input_bp, filters_bp, None, None, None, None
class SubMConvFunction(Function): class SubMConvFunction(Function):
@staticmethod @staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo
return ops.indice_conv(features, filters, indice_pairs, return ops.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out, False, True) indice_pair_num, num_activate_out, False, True, algo=algo)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward( input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num, features, filters, grad_output, indice_pairs, indice_pair_num,
False, True) False, True, algo=ctx.algo)
return input_bp, filters_bp, None, None, None return input_bp, filters_bp, None, None, None, None
class SparseMaxPoolFunction(Function): class SparseMaxPoolFunction(Function):
......
...@@ -16,6 +16,11 @@ import torch ...@@ -16,6 +16,11 @@ import torch
import spconv import spconv
from enum import Enum
class ConvAlgo(Enum):
Native = 0
BatchGemm = 1
def get_conv_output_size(input_size, kernel_size, stride, padding, dilation): def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
ndim = len(input_size) ndim = len(input_size)
...@@ -106,10 +111,11 @@ def indice_conv(features, ...@@ -106,10 +111,11 @@ def indice_conv(features,
indice_pair_num, indice_pair_num,
num_activate_out, num_activate_out,
inverse=False, inverse=False,
subm=False): subm=False,
return torch.ops.spconv.indice_conv_batch(features, filters, indice_pairs, algo=ConvAlgo.Native.value):
return torch.ops.spconv.indice_conv(features, filters, indice_pairs,
indice_pair_num, num_activate_out, indice_pair_num, num_activate_out,
int(inverse), int(subm)) int(inverse), int(subm), algo)
def fused_indice_conv(features, filters, bias, indice_pairs, indice_pair_num, def fused_indice_conv(features, filters, bias, indice_pairs, indice_pair_num,
...@@ -126,10 +132,11 @@ def indice_conv_backward(features, ...@@ -126,10 +132,11 @@ def indice_conv_backward(features,
indice_pairs, indice_pairs,
indice_pair_num, indice_pair_num,
inverse=False, inverse=False,
subm=False): subm=False,
algo=ConvAlgo.Native.value):
return torch.ops.spconv.indice_conv_backward(features, filters, out_bp, return torch.ops.spconv.indice_conv_backward(features, filters, out_bp,
indice_pairs, indice_pair_num, indice_pairs, indice_pair_num,
int(inverse), int(subm)) int(inverse), int(subm), algo)
def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out): def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out):
......
...@@ -48,6 +48,7 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -48,6 +48,7 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
using Index = decltype(IndexValue); using Index = decltype(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>( tv::mp_for_each<kernel_block_t>(
[=, &buffer, &features, &indices, &notFound](auto NumTLP) { [=, &buffer, &features, &indices, &notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
...@@ -80,7 +81,6 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -80,7 +81,6 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
} }
} }
}); });
if (notFound) { if (notFound) {
constexpr int NumTLP = 64; constexpr int NumTLP = 64;
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
...@@ -115,6 +115,7 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures, ...@@ -115,6 +115,7 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = constexpr int vecloadFactor =
sizeof(vecload_type_t) / sizeof(T); // important for half. sizeof(vecload_type_t) / sizeof(T); // important for half.
tv::mp_for_each<kernel_block_t>([=, &outFeatures, &buffer, &indices, tv::mp_for_each<kernel_block_t>([=, &outFeatures, &buffer, &indices,
&notFound](auto NumTLP) { &notFound](auto NumTLP) {
// constexpr int NumILP = NumTLP / (64 / (NumTLP / // constexpr int NumILP = NumTLP / (64 / (NumTLP /
...@@ -183,7 +184,6 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -183,7 +184,6 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
using Index = decltype(IndexValue); using Index = decltype(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>( tv::mp_for_each<kernel_block_t>(
[=, &buffer, &features, &indices, &notFound](auto NumTLP) { [=, &buffer, &features, &indices, &notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4; constexpr int NumILP = NumTLP / 4;
...@@ -204,13 +204,12 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -204,13 +204,12 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
if (size - nHotBlock > 0) { if (size - nHotBlock > 0) {
auto indices_offset = (nHotBlock / feature_stride) * inds_stride + nHotBlock % feature_stride;
batchGatherVecKernel<T, Index, int(NumTLP), NumILP, vecload_type_t> batchGatherVecKernel<T, Index, int(NumTLP), NumILP, vecload_type_t>
<<<dim3(1, numPlanes / NumTLP), <<<dim3(1, numPlanes / NumTLP),
dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0, dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0,
stream>>>(buffer.data_ptr<T>() + nHotBlock * numPlanes, stream>>>(buffer.data_ptr<T>() + nHotBlock * numPlanes,
features.data_ptr<T>(), features.data_ptr<T>(),
indices.data_ptr<Index>() + indices_offset, indices.data_ptr<Index>(),
size - nHotBlock, nHotBlock, numPlanes / vecloadFactor, size - nHotBlock, nHotBlock, numPlanes / vecloadFactor,
inds_stride, feature_stride); inds_stride, feature_stride);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
...@@ -280,7 +279,6 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer, ...@@ -280,7 +279,6 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
if (size - nHotBlock > 0) { if (size - nHotBlock > 0) {
// int indices_offset = (nHotBlock / feature_stride) * inds_stride + nHotBlock % feature_stride;
batchScatterAddGenericKernel<T, Index, int(NumTLP), NumILP> batchScatterAddGenericKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP), <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, stream>>>(outFeatures.data_ptr<T>(), 0, stream>>>(outFeatures.data_ptr<T>(),
......
...@@ -83,7 +83,7 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize, ...@@ -83,7 +83,7 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
return {indices, indicePairs, indiceNum}; return {indices, indicePairs, indiceNum};
} else { } else {
...@@ -127,7 +127,7 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize, ...@@ -127,7 +127,7 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
return {outInds.slice(0, 0, numActOut), indicePairs, indiceNum}; return {outInds.slice(0, 0, numActOut), indicePairs, indiceNum};
} }
...@@ -135,12 +135,27 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize, ...@@ -135,12 +135,27 @@ getIndicePairV2(torch::Tensor indices, int64_t batchSize,
torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs, torch::Tensor indiceNum, torch::Tensor indicePairs, torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse, int64_t _subM) { int64_t numActOut, int64_t _inverse, int64_t _subM,
int64_t algo) {
auto kernelVolume = indiceNum.size(0);
switch (algo) {
case kBatchGemm: {
if (kernelVolume != 1) {
return indiceConvBatch(features, filters, indicePairs, indiceNum,
numActOut, _inverse, _subM);
} else {
break;
}
}
case kNative:
break;
default:
TV_THROW_RT_ERR("unknown algo");
}
bool subM = _subM != 0; bool subM = _subM != 0;
bool inverse = _inverse != 0; bool inverse = _inverse != 0;
auto device = features.device().type(); auto device = features.device().type();
auto ndim = filters.dim() - 2; auto ndim = filters.dim() - 2;
auto kernelVolume = indiceNum.size(0);
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
...@@ -157,7 +172,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -157,7 +172,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::Tensor inputBuffer = torch::Tensor inputBuffer =
torch::zeros({indicePairMaxSize, numInPlanes}, options); torch::zeros({indicePairMaxSize, numInPlanes}, options);
torch::Tensor outputBuffer = torch::Tensor outputBuffer =
torch::zeros({indicePairMaxSize, numOutPlanes}, options); torch::empty({indicePairMaxSize, numOutPlanes}, options);
filters = filters.view({-1, numInPlanes, numOutPlanes}); filters = filters.view({-1, numInPlanes, numOutPlanes});
if (subM) { // the center index of subm conv don't need gather and scatter if (subM) { // the center index of subm conv don't need gather and scatter
// add. // add.
...@@ -191,7 +206,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -191,7 +206,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]); torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]);
if (device == torch::kCPU) { if (device == torch::kCPU) {
...@@ -205,7 +220,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -205,7 +220,7 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
} }
return output; return output;
...@@ -220,33 +235,38 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, ...@@ -220,33 +235,38 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
auto device = features.device().type(); auto device = features.device().type();
auto ndim = filters.dim() - 2; auto ndim = filters.dim() - 2;
auto kernelVolume = indiceNum.size(0); auto kernelVolume = indiceNum.size(0);
TV_ASSERT_INVALID_ARG(kernelVolume > 1, "error");
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairMaxSizeIter = auto indicePairNumVec =
std::max_element(indicePairNumCpu.data_ptr<int>(), std::vector<int>(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume); indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset = auto indicePairMaxSizeIter =
indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>(); std::max_element(indicePairNumVec.begin(), indicePairNumVec.end());
int indicePairMaxOffset = indicePairMaxSizeIter - indicePairNumVec.begin();
int indicePairMaxSize = *indicePairMaxSizeIter; int indicePairMaxSize = *indicePairMaxSizeIter;
std::nth_element(indicePairNumVec.begin(), indicePairNumVec.begin() + 1,
indicePairNumVec.end(), std::greater<int>());
int indicePairTop2Size = indicePairNumVec[1];
auto options = auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device()); torch::TensorOptions().dtype(features.dtype()).device(features.device());
auto indice_dtype = indicePairs.scalar_type();
torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options); torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options);
// we cant use batch conv in subm directly because
// number of indice in the center of filter is much more than other
// filter location.
// so we first use top2 indice num to do batch conv, then
// do native conv in center.
int bufferSize = subM ? indicePairTop2Size : indicePairMaxSize;
torch::Tensor inputBuffer = torch::Tensor inputBuffer =
torch::zeros({kernelVolume, indicePairMaxSize, numInPlanes}, options); torch::zeros({kernelVolume, bufferSize, numInPlanes}, options);
torch::Tensor outputBuffer = torch::Tensor outputBuffer =
torch::zeros({kernelVolume, indicePairMaxSize, numOutPlanes}, options); torch::empty({kernelVolume, bufferSize, numOutPlanes}, options);
filters = filters.view({kernelVolume, numInPlanes, numOutPlanes}); filters = filters.view({kernelVolume, numInPlanes, numOutPlanes});
if (subM) { // the center index of subm conv don't need gather and scatter int64_t size = kernelVolume * bufferSize;
// add.
torch::mm_out(output, features, filters[indicePairMaxOffset]);
}
double totalGatherTime = 0;
double totalGEMMTime = 0;
double totalSAddTime = 0;
auto size = kernelVolume * indicePairMaxSize;
if (device == torch::kCPU) { if (device == torch::kCPU) {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
...@@ -254,11 +274,11 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, ...@@ -254,11 +274,11 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
torch::bmm_out(outputBuffer, inputBuffer, filters); torch::bmm_out(outputBuffer, inputBuffer, filters);
if (device == torch::kCPU) { if (device == torch::kCPU) {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
#ifdef TV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
...@@ -267,7 +287,54 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, ...@@ -267,7 +287,54 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
}
if (subM) {
auto remain_size = indicePairMaxSize - indicePairTop2Size;
if (remain_size <= 0) {
return output;
}
inputBuffer = torch::empty({remain_size, numInPlanes}, options);
outputBuffer = torch::empty({remain_size, numOutPlanes}, options);
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) {
using Index = decltype(I);
auto indicePairsRemain = torch::from_blob(
indicePairs[inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
sparse_gather_cuda(inputBuffer, features, indicePairsRemain,
remain_size);
});
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
torch::mm_out(outputBuffer, inputBuffer, filters[indicePairMaxOffset]);
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) {
using Index = decltype(I);
auto indicePairsRemain = torch::from_blob(
indicePairs[!inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
sparse_scatter_add_cuda(outputBuffer, output, indicePairsRemain,
remain_size);
});
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
} }
return output; return output;
} }
...@@ -275,13 +342,29 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters, ...@@ -275,13 +342,29 @@ torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
std::vector<torch::Tensor> std::vector<torch::Tensor>
indiceConvBackward(torch::Tensor features, torch::Tensor filters, indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs, torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse, int64_t _subM) { torch::Tensor indiceNum, int64_t _inverse, int64_t _subM,
int64_t algo) {
auto kernelVolume = indiceNum.size(0);
switch (algo) {
case kBatchGemm: {
if (kernelVolume != 1) {
return indiceConvBackwardBatch(features, filters, outGrad, indicePairs,
indiceNum, _inverse, _subM);
} else {
break;
}
}
case kNative:
break;
default:
TV_THROW_RT_ERR("unknown algo");
}
bool subM = _subM != 0; bool subM = _subM != 0;
bool inverse = _inverse != 0; bool inverse = _inverse != 0;
auto device = features.device().type(); auto device = features.device().type();
auto ndim = filters.dim() - 2; auto ndim = filters.dim() - 2;
auto kernelVolume = indiceNum.size(0);
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
...@@ -324,7 +407,7 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -324,7 +407,7 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
} }
auto filterGradSub = filtersGrad[i]; auto filterGradSub = filtersGrad[i];
...@@ -346,9 +429,139 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -346,9 +429,139 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
} }
#endif #endif
else { else {
TV_ASSERT_INVALID_ARG(false, "unknown device type"); TV_THROW_INVALID_ARG("unknown device type");
}
}
return {inputGrad, filtersGrad.view(filterShape)};
}
std::vector<torch::Tensor>
indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM) {
bool subM = _subM != 0;
bool inverse = _inverse != 0;
auto device = features.device().type();
auto ndim = filters.dim() - 2;
auto kernelVolume = indiceNum.size(0);
TV_ASSERT_INVALID_ARG(kernelVolume > 1, "error");
auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairNumVec =
std::vector<int>(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
auto indicePairMaxSizeIter =
std::max_element(indicePairNumVec.begin(), indicePairNumVec.end());
int indicePairMaxOffset = indicePairMaxSizeIter - indicePairNumVec.begin();
int indicePairMaxSize = *indicePairMaxSizeIter;
std::nth_element(indicePairNumVec.begin(), indicePairNumVec.begin() + 1,
indicePairNumVec.end(), std::greater<int>());
int indicePairTop2Size = indicePairNumVec[1];
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
auto indice_dtype = indicePairs.scalar_type();
auto filterShape = filters.sizes();
torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
torch::Tensor filtersGrad = torch::zeros(filterShape, options);
int bufferSize = subM ? indicePairTop2Size : indicePairMaxSize;
torch::Tensor inputBuffer =
torch::zeros({kernelVolume, bufferSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::zeros({kernelVolume, bufferSize, numOutPlanes}, options);
filters = filters.view({-1, numInPlanes, numOutPlanes});
filtersGrad = filtersGrad.view({-1, numInPlanes, numOutPlanes});
int64_t size = kernelVolume * bufferSize;
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
batch_sparse_gather_cuda(inputBuffer, features, indicePairs[inverse], size);
batch_sparse_gather_cuda(outputBuffer, outGrad, indicePairs[!inverse],
size);
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
// filters: KV, I, O, inputBuffer: [KV, buffer, I]
// outputBuffer: [KV, buffer, O]
torch::bmm_out(filtersGrad, inputBuffer.permute({0, 2, 1}), outputBuffer);
torch::bmm_out(inputBuffer, outputBuffer, filters.permute({0, 2, 1}));
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
batch_sparse_scatter_add_cuda(inputBuffer, inputGrad, indicePairs[inverse],
size);
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
} }
if (subM) {
auto remain_size = indicePairMaxSize - indicePairTop2Size;
if (remain_size <= 0) {
return {inputGrad, filtersGrad.view(filterShape)};
}
inputBuffer = torch::zeros({remain_size, numInPlanes}, options);
outputBuffer = torch::zeros({remain_size, numOutPlanes}, options);
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
} }
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) {
using Index = decltype(I);
auto indicePairsRemain = torch::from_blob(
indicePairs[inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
auto indicePairsRemain2 = torch::from_blob(
indicePairs[!inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
batch_sparse_gather_cuda(inputBuffer, features, indicePairsRemain,
remain_size);
batch_sparse_gather_cuda(outputBuffer, outGrad, indicePairsRemain2,
remain_size);
});
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
torch::mm_out(filtersGrad, inputBuffer.t(), outputBuffer);
torch::mm_out(inputBuffer, outputBuffer, filters[indicePairMaxOffset].t());
if (device == torch::kCPU) {
TV_THROW_INVALID_ARG("unknown device type");
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
tv::dispatch_torch<int32_t, int64_t>(indice_dtype, [&](auto I) {
using Index = decltype(I);
auto indicePairsRemain2 = torch::from_blob(
indicePairs[!inverse][indicePairMaxOffset].data_ptr<Index>() +
indicePairTop2Size,
{remain_size}, indicePairs.options());
batch_sparse_scatter_add_cuda(inputBuffer, inputGrad,
indicePairsRemain2, remain_size);
});
}
#endif
else {
TV_THROW_INVALID_ARG("unknown device type");
}
}
return {inputGrad, filtersGrad.view(filterShape)}; return {inputGrad, filtersGrad.view(filterShape)};
} }
......
...@@ -30,6 +30,8 @@ class SparseConv3dTestTorch(nn.Module): ...@@ -30,6 +30,8 @@ class SparseConv3dTestTorch(nn.Module):
def __init__(self, num_layers, ndim, shape, in_channels, out_channels, def __init__(self, num_layers, ndim, shape, in_channels, out_channels,
kernel_size, stride, padding, dilation): kernel_size, stride, padding, dilation):
super().__init__() super().__init__()
algo = spconv.ConvAlgo.BatchGemm
layers = [ layers = [
spconv.SparseConv3d(in_channels, spconv.SparseConv3d(in_channels,
out_channels, out_channels,
...@@ -38,7 +40,8 @@ class SparseConv3dTestTorch(nn.Module): ...@@ -38,7 +40,8 @@ class SparseConv3dTestTorch(nn.Module):
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
bias=False, bias=False,
use_hash=False) use_hash=False,
algo=algo)
] ]
for i in range(1, num_layers): for i in range(1, num_layers):
layers.append( layers.append(
...@@ -49,7 +52,8 @@ class SparseConv3dTestTorch(nn.Module): ...@@ -49,7 +52,8 @@ class SparseConv3dTestTorch(nn.Module):
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
bias=False, bias=False,
use_hash=False)) use_hash=False,
algo=algo))
self.net = spconv.SparseSequential(*layers, ) self.net = spconv.SparseSequential(*layers, )
# self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda() # self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda()
self.grid = None self.grid = None
...@@ -64,7 +68,7 @@ class SparseConv3dTestTorch(nn.Module): ...@@ -64,7 +68,7 @@ class SparseConv3dTestTorch(nn.Module):
class SubMConv3dTestTorch(nn.Module): class SubMConv3dTestTorch(nn.Module):
def __init__(self, num_layers, ndim, shape, in_channels, out_channels, def __init__(self, num_layers, ndim, shape, in_channels, out_channels,
kernel_size, stride, padding, dilation): kernel_size, stride, padding, dilation, algo=spconv.ConvAlgo.Native):
super().__init__() super().__init__()
layers = [ layers = [
spconv.SubMConv3d(in_channels, spconv.SubMConv3d(in_channels,
...@@ -73,7 +77,8 @@ class SubMConv3dTestTorch(nn.Module): ...@@ -73,7 +77,8 @@ class SubMConv3dTestTorch(nn.Module):
stride, stride,
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
bias=False) bias=False,
algo=algo)
] ]
for i in range(1, num_layers): for i in range(1, num_layers):
layers.append( layers.append(
...@@ -83,14 +88,15 @@ class SubMConv3dTestTorch(nn.Module): ...@@ -83,14 +88,15 @@ class SubMConv3dTestTorch(nn.Module):
stride, stride,
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
bias=False)) bias=False,
algo=algo))
self.net = spconv.SparseSequential(*layers, ) self.net = spconv.SparseSequential(*layers, )
# self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda() # self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda()
self.grid = None self.grid = None
self.shape = shape self.shape = shape
def forward(self, features, coors, batch_size): def forward(self, features, coors, batch_size):
coors = coors.int() coors = coors.int()# .cpu()
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, x = spconv.SparseConvTensor(features, coors, self.shape, batch_size,
self.grid) self.grid)
return self.net(x) # .dense() return self.net(x) # .dense()
...@@ -327,7 +333,7 @@ def scatter_nd(indices, updates, shape): ...@@ -327,7 +333,7 @@ def scatter_nd(indices, updates, shape):
class TestSpConv(TestCase): class TestSpConv(TestCase):
def testSpConv3d(self): def testSpConv3d(self):
np.random.seed(484) np.random.seed(484)
devices = ["cuda:0", "cpu:0"] devices = ["cuda:0"]
shapes = [[19, 18, 17]] shapes = [[19, 18, 17]]
batchsizes = [1, 2] batchsizes = [1, 2]
...@@ -651,14 +657,18 @@ def main(): ...@@ -651,14 +657,18 @@ def main():
out = net(features_t, indices_t, bs) out = net(features_t, indices_t, bs)
# print(out.indices) # print(out.indices)
out = out.dense() out = out.dense()
out_numpy = out.detach().cpu().numpy()
print( print(
np.linalg.norm(out.detach().cpu().numpy() - np.linalg.norm(out.detach().cpu().numpy() -
out_ref.detach().cpu().numpy())) out_ref.detach().cpu().numpy()))
print(out_numpy.min(), out_numpy.max(), out_numpy.mean(), out_numpy.sum())
def main_subm(): def main_subm(algo):
# function for develop. # function for develop.
np.random.seed(484) np.random.seed(484)
torch.manual_seed(50051)
# devices = ["cuda:0"] # devices = ["cuda:0"]
devices = ["cuda:0"] devices = ["cuda:0"]
shapes = [[50, 30, 30]] shapes = [[50, 30, 30]]
...@@ -670,14 +680,13 @@ def main_subm(): ...@@ -670,14 +680,13 @@ def main_subm():
strides = [1] strides = [1]
paddings = [1] paddings = [1]
dilations = [1] dilations = [1]
for dev, shape, bs, IC, OC, k, s, p, d in params_grid( for dev, shape, bs, IC, OC, k, s, p, d in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes, devices, shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations): strides, paddings, dilations):
if all([s > 1, d > 1]): if all([s > 1, d > 1]):
continue continue
device = torch.device(dev) device = torch.device(dev)
num_points = [500] * bs num_points = [1000] * bs
sparse_dict = generate_sparse_data(shape, num_points, IC) sparse_dict = generate_sparse_data(shape, num_points, IC)
...@@ -694,7 +703,7 @@ def main_subm(): ...@@ -694,7 +703,7 @@ def main_subm():
features_dense_t = torch.from_numpy(features_dense).to(device).float() features_dense_t = torch.from_numpy(features_dense).to(device).float()
net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d).to(device).float() d, algo=algo).to(device).float()
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d).to(device).float() d).to(device).float()
filters_t = torch.from_numpy(filters).to(device).float() filters_t = torch.from_numpy(filters).to(device).float()
...@@ -714,12 +723,16 @@ def main_subm(): ...@@ -714,12 +723,16 @@ def main_subm():
out = net(features_t, indices_t, bs) out = net(features_t, indices_t, bs)
# print(out.indices) # print(out.indices)
out = out.dense() out = out.dense()
out_numpy = out.detach().cpu().numpy()
print( print(
np.linalg.norm(out.detach().cpu().numpy() - np.linalg.norm(out.detach().cpu().numpy() -
out_ref.detach().cpu().numpy())) out_ref.detach().cpu().numpy()))
print(out_numpy.min(), out_numpy.max(), out_numpy.mean(), out_numpy.sum())
return out_numpy
if __name__ == '__main__': if __name__ == '__main__':
main() # out_my = main_subm(algo=spconv.ConvAlgo.BatchGemm)
# out_ref = main_subm(algo=spconv.ConvAlgo.Native)
# TestCase().assertAllClose(out_my, out_ref)
# unittest.main() # unittest.main()
# TestSpConv().testSpConv3d() TestSpConv().testSpConv3d()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment