Commit e2df774f authored by yan.yan's avatar yan.yan
Browse files

fix #532 overflow in huge dim

parent 1f5ce924
...@@ -116,7 +116,7 @@ jobs: ...@@ -116,7 +116,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] # this version is only used for upload. python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] # this version is only used for upload.
cuda-version: ['102', '113', '114', '116', '117', '118'] cuda-version: ['102', '113', '114', '116', '117', '118', '']
steps: steps:
- uses: actions/checkout@master - uses: actions/checkout@master
......
# Changelog # Changelog
## [2.2.5] - 2022-11-05
### Fixed
- Fix overflow when shape is too large
## [2.2.4] - 2022-10-13 ## [2.2.4] - 2022-10-13
### Added ### Added
- Add prebuilt for CUDA 11.8 (RTX 4090 and H100) and CUDA 11.6. - Add prebuilt for CUDA 11.8 (RTX 4090 and H100) and CUDA 11.6.
......
...@@ -41,8 +41,8 @@ ...@@ -41,8 +41,8 @@
[pypi-url-118]: https://pypi.org/project/spconv-cu118/ [pypi-url-118]: https://pypi.org/project/spconv-cu118/
[pypi-download-118]: https://img.shields.io/pypi/dm/spconv-cu118 [pypi-download-118]: https://img.shields.io/pypi/dm/spconv-cu118
[pypi-url-116]: https://pypi.org/project/spconv-cu118/ [pypi-url-116]: https://pypi.org/project/spconv-cu116/
[pypi-download-116]: https://img.shields.io/pypi/dm/spconv-cu118 [pypi-download-116]: https://img.shields.io/pypi/dm/spconv-cu116
# SpConv: Spatially Sparse Convolution Library # SpConv: Spatially Sparse Convolution Library
[![Build Status](https://github.com/traveller59/spconv/workflows/build/badge.svg)](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild) [![Build Status](https://github.com/traveller59/spconv/workflows/build/badge.svg)](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild)
...@@ -57,7 +57,9 @@ ...@@ -57,7 +57,9 @@
| CUDA 11.4 | [![PyPI Version][pypi-ver-114]][pypi-url-114] | ```pip install spconv-cu114```| [![pypi monthly download][pypi-download-114]][pypi-url-114]| | CUDA 11.4 | [![PyPI Version][pypi-ver-114]][pypi-url-114] | ```pip install spconv-cu114```| [![pypi monthly download][pypi-download-114]][pypi-url-114]|
| CUDA 11.6 | [![PyPI Version][pypi-ver-116]][pypi-url-116] | ```pip install spconv-cu116```| [![pypi monthly download][pypi-download-116]][pypi-url-116]| | CUDA 11.6 | [![PyPI Version][pypi-ver-116]][pypi-url-116] | ```pip install spconv-cu116```| [![pypi monthly download][pypi-download-116]][pypi-url-116]|
| CUDA 11.7 | [![PyPI Version][pypi-ver-117]][pypi-url-117] | ```pip install spconv-cu117```| [![pypi monthly download][pypi-download-117]][pypi-url-117]| | CUDA 11.7 | [![PyPI Version][pypi-ver-117]][pypi-url-117] | ```pip install spconv-cu117```| [![pypi monthly download][pypi-download-117]][pypi-url-117]|
| CUDA 11.8 | [![PyPI Version][pypi-ver-118]][pypi-url-118] | ```pip install spconv-cu118```| [![pypi monthly download][pypi-download-118]][pypi-url-118]| | CUDA 11.8* | [![PyPI Version][pypi-ver-118]][pypi-url-118] | ```pip install spconv-cu118```| [![pypi monthly download][pypi-download-118]][pypi-url-118]|
*: sm_89 and sm_90 is added in CUDA 11.8. If you use RTX 4090 or H100, you should use this version.
<!-- | CUDA 12.0 | [![PyPI Version][pypi-ver-120]][pypi-url-120] | ```pip install spconv-cu120```| [![pypi monthly download][pypi-download-120]][pypi-url-120]| --> <!-- | CUDA 12.0 | [![PyPI Version][pypi-ver-120]][pypi-url-120] | ```pip install spconv-cu120```| [![pypi monthly download][pypi-download-120]][pypi-url-120]| -->
......
[build-system] [build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.5"] requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.7"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu118-0.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"] # requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu118-0.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
...@@ -39,9 +39,9 @@ if cuda_ver: ...@@ -39,9 +39,9 @@ if cuda_ver:
cuda_ver_str = cuda_ver.replace(".", "") # 10.2 to 102 cuda_ver_str = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver_str) RELEASE_NAME += "-cu{}".format(cuda_ver_str)
deps = ["cumm-cu{}>=0.3.4".format(cuda_ver_str)] deps = ["cumm-cu{}>=0.3.7".format(cuda_ver_str)]
else: else:
deps = ["cumm>=0.3.4"] deps = ["cumm>=0.3.7"]
......
...@@ -618,7 +618,6 @@ class SimpleConv: ...@@ -618,7 +618,6 @@ class SimpleConv:
] ]
self.prebuilt_desps = prebuilt_desps self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps} self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
self.prebuilt_desp_names.clear()
self.lock = Lock() self.lock = Lock()
self.static_key_to_desps = group_by(self.get_static_key, all_desps) self.static_key_to_desps = group_by(self.get_static_key, all_desps)
......
...@@ -1677,7 +1677,7 @@ class SpconvOps(pccm.Class): ...@@ -1677,7 +1677,7 @@ class SpconvOps(pccm.Class):
}} }}
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end()); std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(), int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>()); output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>()) * batch_size;
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max()); bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32; tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm || TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
...@@ -2022,7 +2022,7 @@ Your Conv Params: )" << "\\n"; ...@@ -2022,7 +2022,7 @@ Your Conv Params: )" << "\\n";
}} }}
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end()); std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(), int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>()); output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>()) * batch_size;
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max()); bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32; tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
......
...@@ -76,11 +76,14 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -76,11 +76,14 @@ class CudaCommonKernel(pccm.ParameterizedClass):
class ConvOutLocIter(pccm.ParameterizedClass): class ConvOutLocIter(pccm.ParameterizedClass):
def __init__(self, problem: ConvProblem): def __init__(self, problem: ConvProblem, use_i64: bool = False):
super().__init__() super().__init__()
self.add_dependency(TensorView) self.add_dependency(TensorView)
self.add_param_class("lociter", problem, "ConvProblem") self.add_param_class("lociter", problem, "ConvProblem")
layout_npq = TensorGeneric(problem.ndim + 1, False) if use_i64:
layout_npq = TensorGeneric(problem.ndim + 1, False, dtypes.int64)
else:
layout_npq = TensorGeneric(problem.ndim + 1, False)
layout_rs = TensorGeneric(problem.ndim, False) layout_rs = TensorGeneric(problem.ndim, False)
self.add_param_class("lociter", layout_npq, "LayoutNPQ") self.add_param_class("lociter", layout_npq, "LayoutNPQ")
...@@ -271,7 +274,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -271,7 +274,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
super().__init__() super().__init__()
self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel) self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel)
self.loc_iter = ConvOutLocIter(problem) self.loc_iter = ConvOutLocIter(problem)
self.loc_iter_64 = ConvOutLocIter(problem, True)
self.add_param_class("spinds", self.loc_iter, "ConvLocIter") self.add_param_class("spinds", self.loc_iter, "ConvLocIter")
self.add_param_class("spinds64", self.loc_iter_64, "ConvLocIter64")
self.add_param_class("spinds", problem, "ConvProblem") self.add_param_class("spinds", problem, "ConvProblem")
self.add_param_class("cudakers", CudaCommonKernel()) self.add_param_class("cudakers", CudaCommonKernel())
self.add_include("tensorview/hash/ops.h") self.add_include("tensorview/hash/ops.h")
...@@ -285,8 +291,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -285,8 +291,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_conv_indices_stage1(self): def calc_conv_indices_stage1(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TIndiceUniq") code.targ("TIndiceUniq")
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
code.arg("indice_pairs", code.arg("indice_pairs",
f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] f"{self.dtype_indices}*") # [2, kernelProd, MaxSize]
...@@ -330,15 +336,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -330,15 +336,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def build_conv_hash_table(self): def build_conv_hash_table(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.targ("TLayoutNPQ")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_out", f"int*") # [N, ndim + 1] code.arg("indices_out", f"int*") # [N, ndim + 1]
code.arg( code.arg(
"indice_pairs_for_uniq", "indice_pairs_for_uniq",
f"const typename TTable::key_type*") # [2, kernelProd, MaxSize] f"const typename TTable::key_type*") # [2, kernelProd, MaxSize]
code.arg("layout_npq", f"TLayoutNPQ") # [N, ndim + 1]
code.arg("layout_npq",
f"spinds::LayoutNPQ") # [2, kernelProd, MaxSize]
code.arg("num_indices", "int") code.arg("num_indices", "int")
...@@ -355,13 +360,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -355,13 +360,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def arange_hash_table_and_assign_out(self): def arange_hash_table_and_assign_out(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.targ("TLayoutNPQ")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_out", f"int*") # [N, ndim + 1] code.arg("indices_out", f"int*") # [N, ndim + 1]
code.arg("count", f"int*") # [N, ndim + 1] code.arg("count", f"int*") # [N, ndim + 1]
code.arg("limit", f"int") # [N, ndim + 1] code.arg("limit", f"int") # [N, ndim + 1]
code.arg("layout_npq", f"TLayoutNPQ") # [N, ndim + 1]
code.arg("layout_npq",
f"spinds::LayoutNPQ") # [2, kernelProd, MaxSize]
code.raw(f""" code.raw(f"""
auto key_ptr = table.key_ptr(); auto key_ptr = table.key_ptr();
...@@ -387,7 +392,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -387,7 +392,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("out_indices_offset", f"typename TTable::key_type *") # [N, ndim + 1] code.arg("out_indices_offset",
f"typename TTable::key_type *") # [N, ndim + 1]
code.arg("count", f"int*") # [N, ndim + 1] code.arg("count", f"int*") # [N, ndim + 1]
code.arg("limit", f"int") # [N, ndim + 1] code.arg("limit", f"int") # [N, ndim + 1]
...@@ -411,12 +417,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -411,12 +417,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def assign_out_indices(self): def assign_out_indices(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("T") code.targ("T")
code.targ("TLayoutNPQ")
code.arg("indices_out", f"int*") # [N, ndim + 1] code.arg("indices_out", f"int*") # [N, ndim + 1]
code.arg("out_indices_offset", f"const T*") # [N, ndim + 1] code.arg("out_indices_offset", f"const T*") # [N, ndim + 1]
code.arg("layout_npq", code.arg("layout_npq", f"TLayoutNPQ") # [N, ndim + 1]
f"spinds::LayoutNPQ") # [2, kernelProd, MaxSize]
code.arg("size", f"int") # [N, ndim + 1] code.arg("size", f"int") # [N, ndim + 1]
code.raw(f""" code.raw(f"""
for (auto i : tv::KernelLoopX<int>(size)) {{ for (auto i : tv::KernelLoopX<int>(size)) {{
layout_npq.inverse(out_indices_offset[i], indices_out + {self.ndim + 1} * i); layout_npq.inverse(out_indices_offset[i], indices_out + {self.ndim + 1} * i);
...@@ -424,7 +429,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -424,7 +429,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage2(self): def calc_conv_indices_stage2(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -497,9 +501,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -497,9 +501,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_conv_indices_stage1_mask(self): def calc_conv_indices_stage1_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TIndiceUniq") code.targ("TIndiceUniq")
code.targ("TConvLocIter")
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
code.arg("indice_pairs_bwd", code.arg("indice_pairs_bwd",
f"{self.dtype_indices}*") # [kernelProd, MaxSize] f"{self.dtype_indices}*") # [kernelProd, MaxSize]
...@@ -545,9 +549,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -545,9 +549,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TIndiceUniq") code.targ("TIndiceUniq")
code.targ("TTable") code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1] code.targ("TConvLocIter")
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
code.arg("indice_pairs_bwd", code.arg("indice_pairs_bwd",
...@@ -710,10 +715,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -710,10 +715,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def build_subm_conv_hash_table(self): def build_subm_conv_hash_table(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.targ("TLayoutNPQ")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
code.arg("layout_npq", f"spinds::LayoutNPQ") code.arg("layout_npq", f"TLayoutNPQ")
code.arg("num_indices", "int") code.arg("num_indices", "int")
...@@ -741,7 +748,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -741,7 +748,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_subm_conv_indices(self): def calc_subm_conv_indices(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
...@@ -790,7 +798,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -790,7 +798,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_subm_conv_indices_mask(self): def calc_subm_conv_indices_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
...@@ -857,7 +866,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -857,7 +866,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_subm_conv_indices_split_mask(self): def calc_subm_conv_indices_split_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
...@@ -952,20 +963,24 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -952,20 +963,24 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// tv::cuda::Launch launcher_num_act_in_2(indices.dim(0)); // tv::cuda::Launch launcher_num_act_in_2(indices.dim(0));
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); bool use_int32 = problem.check_npq_not_overflow();
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int)); tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
"kernel volume must smaller than max value of T");
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1<T>, loc_iter, indices.data_ptr<const int>(),
indice_pairs.data_ptr<{self.dtype_indices}>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
indice_pairs.dim(2), kv, transposed);
}});
""") """)
for x in codeops.dispatch_ints(code, [0, 1], "int(use_int32)"):
loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
code.raw(f"""
{loc_type} loc_iter(problem);
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
"kernel volume must smaller than max value of T");
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1<T, {loc_type}>, loc_iter, indices.data_ptr<const int>(),
indice_pairs.data_ptr<{self.dtype_indices}>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
indice_pairs.dim(2), kv, transposed);
}});
""")
return code # .ret("int") return code # .ret("int")
@pccm.cuda.static_function @pccm.cuda.static_function
...@@ -1029,12 +1044,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1029,12 +1044,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream); tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); bool use_int32 = problem.check_npq_not_overflow();
// TODO handle invalid num_out_act // TODO handle invalid num_out_act
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act); indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream); tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{ """)
with code.block(
"",
"tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){",
"});"):
code.raw(f"""
using V = {self.dtype_indices}; using V = {self.dtype_indices};
using K = TV_DECLTYPE(I); using K = TV_DECLTYPE(I);
using table_t = using table_t =
...@@ -1044,9 +1064,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1044,9 +1064,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0)); table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
tv::hash::clear_map_split(hash, custream); tv::hash::clear_map_split(hash, custream);
// hash.clear(custream); // hash.clear(custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash, """)
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(), for x in codeops.dispatch_ints(code, [0, 1],
loc_iter.layout_npq, num_out_act); "int(use_int32)"):
loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
code.raw(f"""
{loc_type} loc_iter(problem);
lanucher_build_hash(build_conv_hash_table<table_t, std::decay_t<decltype(loc_iter.layout_npq)>>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act);
""")
code.raw(f"""
if (!use_bound_algo){{ if (!use_bound_algo){{
launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash, launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash,
indice_pairs_uniq_before_sort.data_ptr<const K>(), indice_pairs_uniq_before_sort.data_ptr<const K>(),
...@@ -1070,7 +1098,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1070,7 +1098,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indices.dim(0), indices.dim(0),
indice_pairs.dim(2)); indice_pairs.dim(2));
}} }}
}}); """)
code.raw(f"""
return num_out_act; return num_out_act;
""") """)
return code.ret("int") return code.ret("int")
...@@ -1108,28 +1137,32 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1108,28 +1137,32 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// tv::cuda::Launch launcher_num_act_in_2(indices.dim(0)); // tv::cuda::Launch launcher_num_act_in_2(indices.dim(0));
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); bool use_int32 = problem.check_npq_not_overflow();
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int)); tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
"kernel volume must smaller than max value of T");
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1_mask<T>, loc_iter, indices.data_ptr<const int>(),
indice_pairs_bwd.data_ptr<{self.dtype_indices}>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
kv, transposed);
}});
""") """)
return code # .ret("int")
for x in codeops.dispatch_ints(code, [0, 1], "int(use_int32)"):
loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
code.raw(f"""
{loc_type} loc_iter(problem);
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
"kernel volume must smaller than max value of T");
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1_mask<T, {loc_type}>, loc_iter, indices.data_ptr<const int>(),
indice_pairs_bwd.data_ptr<{self.dtype_indices}>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
kv, transposed);
}});
""")
return code # .ret("int")
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_mask_stage1_direct_table(self): def generate_conv_inds_mask_stage1_direct_table(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs_bwd, indice_pairs_uniq", code.arg("indice_pairs_bwd, indice_pairs_uniq", "tv::Tensor")
"tv::Tensor")
code.arg("indice_num_per_loc", "tv::Tensor") code.arg("indice_num_per_loc", "tv::Tensor")
code.arg("batch_size", "int") code.arg("batch_size", "int")
code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>") code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>")
...@@ -1158,9 +1191,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1158,9 +1191,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// tv::cuda::Launch launcher_num_act_in_2(indices.dim(0)); // tv::cuda::Launch launcher_num_act_in_2(indices.dim(0));
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int)); tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{ bool use_int32 = problem.check_npq_not_overflow();
""")
with code.block(
"",
"tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){",
"});"):
code.raw(f"""
using V = {self.dtype_indices}; using V = {self.dtype_indices};
using K = TV_DECLTYPE(I); using K = TV_DECLTYPE(I);
using table_t = using table_t =
...@@ -1172,17 +1211,21 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1172,17 +1211,21 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(), TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
"kernel volume must smaller than max value of T"); "kernel volume must smaller than max value of T");
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size); launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1_mask_direct_table<T, table_t>, table, """)
loc_iter, indices.data_ptr<const int>(), for x in codeops.dispatch_ints(code, [0, 1],
indice_pairs_bwd.data_ptr<{self.dtype_indices}>(), "int(use_int32)"):
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
indices.dim(0), code.raw(f"""
kv, transposed); {loc_type} loc_iter(problem);
}}); launcher_num_act_in(calc_conv_indices_stage1_mask_direct_table<T, table_t, {loc_type}>, table,
""") loc_iter, indices.data_ptr<const int>(),
indice_pairs_bwd.data_ptr<{self.dtype_indices}>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(),
indices.dim(0),
kv, transposed);
""")
return code return code
def generate_conv_inds_stage2_mask_template(self, is_direct_table: bool): def generate_conv_inds_stage2_mask_template(self, is_direct_table: bool):
"""here indice_pairs_uniq may be bounded, some """here indice_pairs_uniq may be bounded, some
points may be dropped. points may be dropped.
...@@ -1233,8 +1276,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1233,8 +1276,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::cuda::Launch launcher_num_act_in_no_y(num_act_in, custream); tv::cuda::Launch launcher_num_act_in_no_y(num_act_in, custream);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream); tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
bool use_int32 = problem.check_npq_not_overflow();
// TODO handle invalid num_out_act // TODO handle invalid num_out_act
""") """)
...@@ -1242,8 +1286,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1242,8 +1286,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act); indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
""") """)
with code.block("", start="tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){", with code.block(
end="});"): "",
start=
"tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){",
end="});"):
code.raw(f""" code.raw(f"""
using V = {self.dtype_indices}; using V = {self.dtype_indices};
using K = TV_DECLTYPE(I); using K = TV_DECLTYPE(I);
...@@ -1254,13 +1301,19 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1254,13 +1301,19 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0)); table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
""") """)
if not is_direct_table: if not is_direct_table:
# direct table built in stage 1.
code.raw(f""" code.raw(f"""
tv::hash::clear_map_split(hash, custream); tv::hash::clear_map_split(hash, custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act);
""") """)
# direct table built in stage 1.
for x in codeops.dispatch_ints(code, [0, 1],
"int(use_int32)"):
loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
code.raw(f"""
{loc_type} loc_iter(problem);
lanucher_build_hash(build_conv_hash_table<table_t, std::decay_t<decltype(loc_iter.layout_npq)>>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act);
""")
code.raw(f""" code.raw(f"""
if (!mask_bwd.empty()){{ if (!mask_bwd.empty()){{
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t, {pccm.literal(is_direct_table)}>, hash, launcher_num_act_in(calc_conv_indices_stage2_mask<table_t, {pccm.literal(is_direct_table)}>, hash,
...@@ -1293,14 +1346,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1293,14 +1346,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
return num_out_act; return num_out_act;
""") """)
return code.ret("int") return code.ret("int")
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_stage2_mask(self): def generate_conv_inds_stage2_mask(self):
"""here indice_pairs_uniq may be bounded, some """here indice_pairs_uniq may be bounded, some
points may be dropped. points may be dropped.
""" """
return self.generate_conv_inds_stage2_mask_template(False) return self.generate_conv_inds_stage2_mask_template(False)
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_stage2_mask_direct_table(self): def generate_conv_inds_stage2_mask_direct_table(self):
"""here indice_pairs_uniq may be bounded, some """here indice_pairs_uniq may be bounded, some
...@@ -1314,9 +1367,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1314,9 +1367,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""" """
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("hashdata_k, hashdata_v, uniq_cnt", "tv::Tensor") code.arg("hashdata_k, hashdata_v, uniq_cnt", "tv::Tensor")
code.arg( code.arg("out_inds", "tv::Tensor")
"out_inds",
"tv::Tensor")
code.arg("num_out_bound", "int") code.arg("num_out_bound", "int")
code.arg("batch_size", "int") code.arg("batch_size", "int")
code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>") code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>")
...@@ -1328,23 +1379,30 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1328,23 +1379,30 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
tv::cuda::Launch lanucher_build_hash(hashdata_k.size(), custream); tv::cuda::Launch lanucher_build_hash(hashdata_k.size(), custream);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); bool use_int32 = problem.check_npq_not_overflow();
auto tvctx = tv::Context(); auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int)); tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
if (num_out_bound <= 0){{ if (num_out_bound <= 0){{
num_out_bound = hashdata_k.size(); num_out_bound = hashdata_k.size();
}} }}
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{ """)
using V = {self.dtype_indices}; for x in codeops.dispatch_ints(code, [0, 1], "int(use_int32)"):
using K = TV_DECLTYPE(I); loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
using table_t = code.raw(f"""
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, {loc_type} loc_iter(problem);
tv::hash::default_empty_key_v<K>, false>; tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
table_t table = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0)); using V = {self.dtype_indices};
lanucher_build_hash(arange_hash_table_and_assign_out<table_t>, table, using K = TV_DECLTYPE(I);
out_inds.data_ptr<int>(), uniq_cnt.data_ptr<int>(), num_out_bound, using table_t =
loc_iter.layout_npq); tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
}}); tv::hash::default_empty_key_v<K>, false>;
table_t table = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
lanucher_build_hash(arange_hash_table_and_assign_out<table_t, std::decay_t<decltype(loc_iter.layout_npq)>>, table,
out_inds.data_ptr<int>(), uniq_cnt.data_ptr<int>(), num_out_bound,
loc_iter.layout_npq);
}});
""")
code.raw(f"""
auto uniq_cnt_cpu = uniq_cnt.cpu(tvctx); auto uniq_cnt_cpu = uniq_cnt.cpu(tvctx);
return std::min(uniq_cnt_cpu.data_ptr<int>()[0], num_out_bound); return std::min(uniq_cnt_cpu.data_ptr<int>()[0], num_out_bound);
""") """)
...@@ -1355,7 +1413,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1355,7 +1413,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"""unique by hash """unique by hash
""" """
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("hashdata_k, hashdata_v, uniq_cnt, out_indices_offset", "tv::Tensor") code.arg("hashdata_k, hashdata_v, uniq_cnt, out_indices_offset",
"tv::Tensor")
code.arg("num_out_bound", "int") code.arg("num_out_bound", "int")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
...@@ -1400,16 +1459,22 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1400,16 +1459,22 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::cuda::Launch lanucher_build_hash(out_inds.dim(0), custream); tv::cuda::Launch lanucher_build_hash(out_inds.dim(0), custream);
TV_ASSERT_RT_ERR(out_indices_offset.dim(0) >= out_inds.dim(0), "error"); TV_ASSERT_RT_ERR(out_indices_offset.dim(0) >= out_inds.dim(0), "error");
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); bool use_int32 = problem.check_npq_not_overflow();
auto tvctx = tv::Context(); auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int)); tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
tv::dispatch<int32_t, int64_t>(out_indices_offset.dtype(), [&](auto I){{
using K = TV_DECLTYPE(I);
lanucher_build_hash(assign_out_indices<K>, out_inds.data_ptr<int>(),
out_indices_offset.data_ptr<const K>(),
loc_iter.layout_npq, out_inds.dim(0));
}});
""") """)
for x in codeops.dispatch_ints(code, [0, 1], "int(use_int32)"):
loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
code.raw(f"""
{loc_type} loc_iter(problem);
tv::dispatch<int32_t, int64_t>(out_indices_offset.dtype(), [&](auto I){{
using K = TV_DECLTYPE(I);
lanucher_build_hash(assign_out_indices<K, std::decay_t<decltype(loc_iter.layout_npq)>>, out_inds.data_ptr<int>(),
out_indices_offset.data_ptr<const K>(),
loc_iter.layout_npq, out_inds.dim(0));
}});
""")
return code return code
@pccm.cuda.static_function @pccm.cuda.static_function
...@@ -1451,57 +1516,61 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1451,57 +1516,61 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
launcher_num_act_in.blocks.y = (kv / 2) + 1; launcher_num_act_in.blocks.y = (kv / 2) + 1;
// launcher_num_act_in.blocks.y = kv; // launcher_num_act_in.blocks.y = kv;
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); bool use_int32 = problem.check_npq_not_overflow();
tv::cuda::Launch lanucher_build_hash(num_act_in_real, custream); tv::cuda::Launch lanucher_build_hash(num_act_in_real, custream);
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{ """)
using V = {self.dtype_indices}; for x in codeops.dispatch_ints(code, [0, 1], "int(use_int32)"):
using K = TV_DECLTYPE(I); loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<K>::max(), code.raw(f"""
"kernel volume must smaller than max value of K"); {loc_type} loc_iter(problem);
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using table_t = using V = {self.dtype_indices};
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, using K = TV_DECLTYPE(I);
tv::hash::default_empty_key_v<K>, false>; TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<K>::max(),
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_act_in_real, "hash size not enough"); "kernel volume must smaller than max value of K");
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
tv::hash::clear_map_split(hash, custream); using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::default_empty_key_v<K>, false>;
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_act_in_real, "hash size not enough");
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
tv::hash::clear_map_split(hash, custream);
lanucher_build_hash(build_subm_conv_hash_table<table_t, std::decay_t<decltype(loc_iter.layout_npq)>>, hash, indices.data_ptr<const int>(),
loc_iter.layout_npq, num_act_in_real);
if (!indice_pair_mask.empty()){{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error");
// indice_pair_mask: [mask_split_count, num_act_in]
if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real);
auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
mask_1.zero_(ctx);
auto kernel = &calc_subm_conv_indices_split_mask<table_t, {loc_type}>;
launcher_num_act_in(kernel, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
mask_0.data_ptr<uint32_t>(), mask_1.data_ptr<uint32_t>(),
indices.dim(0), indice_pairs.dim(2), kv, is_train);
lanucher_build_hash(build_subm_conv_hash_table<table_t>, hash, indices.data_ptr<const int>(), }}else{{
loc_iter.layout_npq, num_act_in_real); // indice_pair_mask: [1, num_act_in]
if (!indice_pair_mask.empty()){{ tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error"); lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error"); TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error"); launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash,
// indice_pair_mask: [mask_split_count, num_act_in] indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
if (indice_pair_mask.dim(0) == 2){{ indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv, is_train);
auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real); }}
auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
mask_1.zero_(ctx);
auto kernel = &calc_subm_conv_indices_split_mask<table_t>;
launcher_num_act_in(kernel, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
mask_0.data_ptr<uint32_t>(), mask_1.data_ptr<uint32_t>(),
indices.dim(0), indice_pairs.dim(2), kv, is_train);
}}else{{ }}else{{
// indice_pair_mask: [1, num_act_in] TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
tv::cuda::Launch lanucher_fill(num_act_in_real, custream); TV_ASSERT_RT_ERR(indice_pairs.dim(0) == 2, "error");
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0)); launcher_num_act_in(calc_subm_conv_indices<table_t, {loc_type}>, loc_iter, hash, indices.data_ptr<const int>(),
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error"); indice_pairs.data_ptr<int>(),
launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash, indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs.dim(2), kv);
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv, is_train);
}} }}
}}else{{ }});
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == 2, "error");
launcher_num_act_in(calc_subm_conv_indices<table_t>, loc_iter, hash, indices.data_ptr<const int>(),
indice_pairs.data_ptr<int>(),
indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs.dim(2), kv);
}}
}});
return indices.dim(0); return indices.dim(0);
""") """)
...@@ -1515,7 +1584,9 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1515,7 +1584,9 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
self.add_dependency(TensorView) self.add_dependency(TensorView)
self.add_include("unordered_map") self.add_include("unordered_map")
self.loc_iter = ConvOutLocIter(problem) self.loc_iter = ConvOutLocIter(problem)
self.loc_iter_64 = ConvOutLocIter(problem, True)
self.add_param_class("spinds", self.loc_iter, "ConvLocIter") self.add_param_class("spinds", self.loc_iter, "ConvLocIter")
self.add_param_class("spinds64", self.loc_iter_64, "ConvLocIter64")
self.add_param_class("spinds", problem, "ConvProblem") self.add_param_class("spinds", problem, "ConvProblem")
self.ndim = problem.ndim self.ndim = problem.ndim
...@@ -1532,7 +1603,6 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1532,7 +1603,6 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
code.arg("batch_size", "int") code.arg("batch_size", "int")
code.arg("input_dims", f"tv::array<int, {self.ndim}>") code.arg("input_dims", f"tv::array<int, {self.ndim}>")
code.arg("ksize, dilation", f"tv::array<int, {self.ndim}>") code.arg("ksize, dilation", f"tv::array<int, {self.ndim}>")
code.raw(f""" code.raw(f"""
tv::array<int, {self.ndim}> stride, padding; tv::array<int, {self.ndim}> stride, padding;
for (int i = 0; i < {self.ndim}; ++i){{ for (int i = 0; i < {self.ndim}; ++i){{
...@@ -1544,47 +1614,54 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1544,47 +1614,54 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<{self.dtype_indices}>::max(), TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}"); "kernel volume must smaller than max value of {self.dtype_indices}");
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); bool use_int32 = problem.check_npq_not_overflow();
int indices_pair_size = indice_pairs.dim(2); """)
int indices_pair_size_mul_RS = indices_pair_size * kv; for x in codeops.dispatch_ints(code, [0, 1], "int(use_int32)"):
auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>(); loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash; code.raw(f"""
auto indices_ptr = indices.data_ptr<const {self.dtype_indices}>(); {loc_type} loc_iter(problem);
int indice_in_num = indices.dim(0); int indices_pair_size = indice_pairs.dim(2);
for (int i = 0; i < indice_in_num; ++i){{ int indices_pair_size_mul_RS = indices_pair_size * kv;
{self.dtype_indices} index = loc_iter.layout_npq(indices_ptr); auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>();
hash.insert({{index, i}}); std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash;
indices_ptr += {self.ndim + 1}; auto indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
}} int indice_in_num = indices.dim(0);
for (int filter_offset = 0; filter_offset < (kv / 2 + 1); ++filter_offset){{ for (int i = 0; i < indice_in_num; ++i){{
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size; {self.dtype_indices} index = loc_iter.layout_npq(indices_ptr);
int filter_offset_mul_indices_pair_size_1 = (kv - 1 - filter_offset) * indices_pair_size; hash.insert({{index, i}});
if (filter_offset == kv / 2){{ indices_ptr += {self.ndim + 1};
for (int i = 0; i < indice_in_num; ++i){{ }}
indice_pairs_ptr[filter_offset_mul_indices_pair_size + i] = i; for (int filter_offset = 0; filter_offset < (kv / 2 + 1); ++filter_offset){{
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + i] = i; int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
}} int filter_offset_mul_indices_pair_size_1 = (kv - 1 - filter_offset) * indices_pair_size;
}}else{{ if (filter_offset == kv / 2){{
indices_ptr = indices.data_ptr<const {self.dtype_indices}>(); for (int i = 0; i < indice_in_num; ++i){{
auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset; indice_pairs_ptr[filter_offset_mul_indices_pair_size + i] = i;
for (int i = 0; i < indice_in_num; ++i){{ indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + i] = i;
tv::array<int, {self.ndim + 1}> npq_offset; }}
if (loc_iter.query_npq_no_stride(indices_ptr, npq_offset)){{ }}else{{
auto index = loc_iter.layout_npq(npq_offset); indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
auto iter = hash.find(index); auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset;
if (iter != hash.end()){{ for (int i = 0; i < indice_in_num; ++i){{
auto old_num = indice_num_per_loc_ptr[0]++; tv::array<int, {self.ndim + 1}> npq_offset;
indice_pairs_ptr[filter_offset_mul_indices_pair_size + old_num] = i; if (loc_iter.query_npq_no_stride(indices_ptr, npq_offset)){{
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = iter->second; auto index = loc_iter.layout_npq(npq_offset);
indice_pairs_ptr[filter_offset_mul_indices_pair_size_1 + old_num] = iter->second; auto iter = hash.find(index);
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + old_num] = i; if (iter != hash.end()){{
auto old_num = indice_num_per_loc_ptr[0]++;
indice_pairs_ptr[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = iter->second;
indice_pairs_ptr[filter_offset_mul_indices_pair_size_1 + old_num] = iter->second;
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + old_num] = i;
}}
}} }}
indices_ptr += {self.ndim + 1};
}} }}
indices_ptr += {self.ndim + 1};
}} }}
++loc_iter;
}} }}
++loc_iter; """)
}} code.raw(f"""
return indices.dim(0); return indices.dim(0);
""") """)
return code.ret("int") return code.ret("int")
...@@ -1602,51 +1679,59 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1602,51 +1679,59 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); bool use_int32 = problem.check_npq_not_overflow();
int indices_pair_size = indice_pairs.dim(2);
int indices_pair_size_mul_RS = indices_pair_size * kv;
auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>();
std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash;
auto indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
auto out_inds_ptr = out_inds.data_ptr<{self.dtype_indices}>();
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}");
int indice_in_num = indices.dim(0);
int num_act = 0; int num_act = 0;
{self.dtype_indices} hashval;
for (int filter_offset = 0; filter_offset < kv; ++filter_offset){{ """)
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size; for x in codeops.dispatch_ints(code, [0, 1], "int(use_int32)"):
indices_ptr = indices.data_ptr<const {self.dtype_indices}>(); loc_type = "ConvLocIter" if x == 1 else "ConvLocIter64"
auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset; code.raw(f"""
for (int i = 0; i < indice_in_num; ++i){{ {loc_type} loc_iter(problem);
tv::array<int, {self.ndim + 1}> npq_offset;
bool valid; int indices_pair_size = indice_pairs.dim(2);
if (transposed){{ int indices_pair_size_mul_RS = indices_pair_size * kv;
valid = loc_iter.query_nhw_out(indices_ptr, npq_offset); auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>();
}}else{{ std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash;
valid = loc_iter.query_npq(indices_ptr, npq_offset); auto indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
}} auto out_inds_ptr = out_inds.data_ptr<{self.dtype_indices}>();
if (valid){{ TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<{self.dtype_indices}>::max(),
auto index = loc_iter.layout_npq(npq_offset); "kernel volume must smaller than max value of {self.dtype_indices}");
auto iter = hash.find(index); int indice_in_num = indices.dim(0);
if (iter == hash.end()){{ {self.dtype_indices} hashval;
hashval = num_act++; for (int filter_offset = 0; filter_offset < kv; ++filter_offset){{
hash.insert({{index, hashval}}); int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
for (int k = 0; k < {self.ndim + 1}; ++k){{ indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
out_inds_ptr[k] = npq_offset[k]; auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset;
}} for (int i = 0; i < indice_in_num; ++i){{
out_inds_ptr += {self.ndim + 1}; tv::array<int, {self.ndim + 1}> npq_offset;
bool valid;
if (transposed){{
valid = loc_iter.query_nhw_out(indices_ptr, npq_offset);
}}else{{ }}else{{
hashval = iter->second; valid = loc_iter.query_npq(indices_ptr, npq_offset);
}}
if (valid){{
auto index = loc_iter.layout_npq(npq_offset);
auto iter = hash.find(index);
if (iter == hash.end()){{
hashval = num_act++;
hash.insert({{index, hashval}});
for (int k = 0; k < {self.ndim + 1}; ++k){{
out_inds_ptr[k] = npq_offset[k];
}}
out_inds_ptr += {self.ndim + 1};
}}else{{
hashval = iter->second;
}}
indice_pairs_ptr[filter_offset_mul_indices_pair_size + indice_num_per_loc_ptr[0]] = i;
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + indice_num_per_loc_ptr[0]++] = hashval;
}} }}
indice_pairs_ptr[filter_offset_mul_indices_pair_size + indice_num_per_loc_ptr[0]] = i; indices_ptr += {self.ndim + 1};
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + indice_num_per_loc_ptr[0]++] = hashval;
}} }}
indices_ptr += {self.ndim + 1}; ++loc_iter;
}} }}
++loc_iter; """)
}} code.raw(f"""
return num_act; return num_act;
""") """)
return code.ret("int") return code.ret("int")
...@@ -185,7 +185,7 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -185,7 +185,7 @@ def get_indice_pairs(indices: torch.Tensor,
) )
assert algo == ConvAlgo.Native, "TODO" assert algo == ConvAlgo.Native, "TODO"
# indices = indices.cpu() # indices = indices.cpu()
spatial_volume = functools.reduce(lambda x, y: x * y, out_shape, 1) spatial_volume = functools.reduce(lambda x, y: x * y, out_shape, 1) * batch_size
use_int64_hash_k = spatial_volume >= INT32_MAX or DEBUG_INT64_HASH_K use_int64_hash_k = spatial_volume >= INT32_MAX or DEBUG_INT64_HASH_K
indice_dtype = torch.int64 if use_int64_hash_k else indices.dtype indice_dtype = torch.int64 if use_int64_hash_k else indices.dtype
pair = torch.full((2, kv, indices.shape[0]), pair = torch.full((2, kv, indices.shape[0]),
...@@ -457,7 +457,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -457,7 +457,7 @@ def get_indice_pairs_implicit_gemm(
raise ValueError( raise ValueError(
f"your out spatial shape {out_shape} reach zero!!! input shape: {spatial_shape}" f"your out spatial shape {out_shape} reach zero!!! input shape: {spatial_shape}"
) )
spatial_volume = functools.reduce(lambda x, y: x * y, spatial_shape, 1) spatial_volume = functools.reduce(lambda x, y: x * y, spatial_shape, 1) * batch_size
use_int64_hash_k = spatial_volume >= INT32_MAX or DEBUG_INT64_HASH_K use_int64_hash_k = spatial_volume >= INT32_MAX or DEBUG_INT64_HASH_K
indice_dtype = torch.int64 if use_int64_hash_k else indices.dtype indice_dtype = torch.int64 if use_int64_hash_k else indices.dtype
assert algo == ConvAlgo.MaskImplicitGemm or algo == ConvAlgo.MaskSplitImplicitGemm, "TODO" assert algo == ConvAlgo.MaskImplicitGemm or algo == ConvAlgo.MaskSplitImplicitGemm, "TODO"
......
...@@ -145,7 +145,8 @@ def generate_sparse_data(shape, ...@@ -145,7 +145,8 @@ def generate_sparse_data(shape,
integer=False, integer=False,
data_range=(-1, 1), data_range=(-1, 1),
with_dense=True, with_dense=True,
dtype=np.float32): dtype=np.float32,
shape_scale = 1):
dense_shape = shape dense_shape = shape
ndim = len(dense_shape) ndim = len(dense_shape)
# num_points = np.random.randint(10, 100, size=[batch_size, ndim]) # num_points = np.random.randint(10, 100, size=[batch_size, ndim])
...@@ -153,9 +154,9 @@ def generate_sparse_data(shape, ...@@ -153,9 +154,9 @@ def generate_sparse_data(shape,
# num_points = np.array([3, 2]) # num_points = np.array([3, 2])
batch_size = len(num_points) batch_size = len(num_points)
batch_indices = [] batch_indices = []
coors_total = np.stack(np.meshgrid(*[np.arange(0, s) for s in shape]), coors_total = np.stack(np.meshgrid(*[np.arange(0, s // shape_scale) for s in shape]),
axis=-1) axis=-1)
coors_total = coors_total.reshape(-1, ndim) coors_total = coors_total.reshape(-1, ndim) * shape_scale
for i in range(batch_size): for i in range(batch_size):
np.random.shuffle(coors_total) np.random.shuffle(coors_total)
inds_total = coors_total[:num_points[i]] inds_total = coors_total[:num_points[i]]
......
import spconv import spconv.pytorch as spconv
from spconv.core import ConvAlgo
import spconv.pytorch as spconv
from spconv.test_utils import TestCase, generate_sparse_data, params_grid
from spconv.core_cc.cumm.common import CompileInfo import torch
if __name__ == "__main__": import numpy as np
print(CompileInfo.arch_is_compatible_gemm((9, 0)), CompileInfo.arch_is_compiled_gemm((9, 0))) class SparseMaxPool2dTestTorch(torch.nn.Module):
print(CompileInfo.arch_is_compatible_gemm((8, 6)), CompileInfo.arch_is_compiled_gemm((8, 6))) def __init__(self, num_layers, ndim, shape, kernel_size, stride, padding,
\ No newline at end of file dilation, algo):
super().__init__()
self.algo = algo
layers = [
spconv.SparseMaxPool2d(kernel_size, stride, padding, dilation, algo=algo)
]
for i in range(1, num_layers):
layers.append(
spconv.SparseMaxPool2d(kernel_size, stride, padding, dilation, algo=algo))
self.net = spconv.SparseSequential(*layers, )
self.shape = shape
def forward(self, features, coors, batch_size):
coors = coors.int()
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size)
return self.net(x) # .dense()
shapes = [[65536, 65536]]
batchsizes = [32]
in_channels = [32]
out_channels = [32]
ksizes = [2]
strides = [2]
paddings = [0]
dilations = [1]
algos = [
# ConvAlgo.Native,
ConvAlgo.MaskImplicitGemm,
# ConvAlgo.MaskSplitImplicitGemm
]
devices = ["cuda:0"]
for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations, algos):
device = torch.device(dev)
num_points = [1000] * bs
print(1)
sparse_dict = generate_sparse_data(shape,
num_points,
IC,
with_dense=False,
data_range=[0.1, 1],
shape_scale = 64)
print(2)
net = SparseMaxPool2dTestTorch(1, 2, shape, k, s, p, d, al).to(device)
features = np.ascontiguousarray(sparse_dict["features"]).astype(
np.float32)
indices = np.ascontiguousarray(
sparse_dict["indices"][:, [2, 0, 1]]).astype(np.int32)
print(indices.max(0))
indices_t = torch.from_numpy(indices).int().to(device)
features_t = torch.from_numpy(features).to(device)
features_t.requires_grad = True
out = net(features_t, indices_t, bs)
print(out.indices.min(0))
...@@ -916,8 +916,8 @@ def _test_native_conv_cuda(subm: bool): ...@@ -916,8 +916,8 @@ def _test_native_conv_cuda(subm: bool):
def test_all_algo_unit(): def test_all_algo_unit():
# for i in range(5): # for i in range(5):
# _test_impgemm_conv_cuda(True) _test_impgemm_conv_cuda(True)
# _test_impgemm_conv_cuda(False) _test_impgemm_conv_cuda(False)
_test_native_conv_cuda(True) _test_native_conv_cuda(True)
_test_native_conv_cuda(False) _test_native_conv_cuda(False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment