Unverified Commit 4d54f765 authored by FindDefinition's avatar FindDefinition Committed by GitHub
Browse files

Merge pull request #363 from traveller59/feature/v2.1

v2.1
parents fa995a4f eae6a3bd
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from ...cumm.gemm.main import GemmAlgoDesp
from cumm.tensorview import Tensor
class ConvAlgoDesp(GemmAlgoDesp):
    """Stub declaration of a convolution algorithm descriptor.

    Subclasses ``GemmAlgoDesp`` and declares the additional attributes and
    methods listed below. Every body is ``...``: this block only gives
    signatures, the behaviour lives in the implementation it describes.
    """
    # NOTE(review): the integer encodings used by the fields below (op type,
    # iterator algorithm, layouts, interleaves) are not visible in this stub;
    # confirm them against the implementation before relying on values.
    ndim: int
    op_type: int
    iter_algo: int
    layout_i: int
    layout_w: int
    layout_o: int
    interleave_i: int
    interleave_w: int
    interleave_o: int
    mask_sparse: bool
    increment_k_first: bool
    def __init__(self, ndim: int, op_type: int) -> None:
        """Create a descriptor from ``ndim`` and ``op_type``.

        Args:
            ndim: presumably the number of spatial dimensions; TODO confirm.
            op_type: integer operation-type code; encoding not shown here.
        """
        ...
    def __repr__(self) -> str: ...
    @staticmethod
    def conv_iwo_012_to_abc(op_type: int) -> List[int]:
        """Return a list of ints computed from ``op_type``.

        NOTE(review): judging by the name, this maps convolution operand
        positions (input/weight/output) to GEMM operand positions (A/B/C);
        confirm against the implementation.

        Args:
            op_type: integer operation-type code.
        """
        ...
    @staticmethod
    def gemm_abc_012_to_iwo(op_type: int) -> List[int]:
        """Return a list of ints computed from ``op_type``.

        NOTE(review): judging by the name, this is the inverse direction of
        ``conv_iwo_012_to_abc`` (GEMM A/B/C to conv input/weight/output);
        confirm against the implementation.

        Args:
            op_type: integer operation-type code.
        """
        ...
    @property
    def dtype_input(self) -> int: ...
    @property
    def dtype_weight(self) -> int: ...
    @property
    def dtype_output(self) -> int: ...
    def supported(self, m: int, n: int, k: int, C: int, K: int, mask_width: int) -> bool:
        """Return a bool for the given problem sizes.

        NOTE(review): presumably reports whether this algorithm can run the
        described problem; the exact criteria are not visible here.

        Args:
            m: integer size argument (presumably GEMM M; TODO confirm).
            n: integer size argument (presumably GEMM N; TODO confirm).
            k: integer size argument (presumably GEMM K; TODO confirm).
            C: integer size argument (presumably input channels; TODO confirm).
            K: integer size argument (presumably output channels; TODO confirm).
            mask_width: integer mask width; semantics not shown here.
        """
        ...
    def query_conv_workspace_size(self, m: int, n: int, k: int, split_k_slices: int, kv: int) -> int:
        """Return an int for the given problem sizes.

        NOTE(review): by name this is a workspace size; the unit (bytes or
        elements) is not visible in this stub.

        Args:
            m: integer size argument.
            n: integer size argument.
            k: integer size argument.
            split_k_slices: integer slice count; semantics not shown here.
            kv: integer; presumably the kernel volume; TODO confirm.
        """
        ...
    def supported_ldx_conv(self, ldi: int, ldw: int, ldo: int) -> bool:
        """Return a bool for the given ``ldi``/``ldw``/``ldo`` values.

        NOTE(review): presumably checks leading dimensions of the input,
        weight and output operands; confirm against the implementation.

        Args:
            ldi: integer argument for the input operand.
            ldw: integer argument for the weight operand.
            ldo: integer argument for the output operand.
        """
        ...
class ConvParams:
    """Stub declaration of the parameter bundle passed to the convolution entry point.

    ``ConvMainUnitTest.implicit_gemm2`` in this same stub takes an instance of
    this class. Only attribute names, types and declared defaults are given
    here; how each field is consumed is not visible in this block.
    """
    # Declared as ``Any`` in this stub; presumably a ConvAlgoDesp. TODO confirm.
    conv_algo_desp: Any
    input: Tensor
    weight: Tensor
    output: Tensor
    split_k_slices: int
    padding: List[int]
    stride: List[int]
    dilation: List[int]
    alpha: float
    beta: float
    mask_width: int
    mask_filter: int
    reverse_mask: bool
    verbose: bool
    # The tensor fields below are declared with a default-constructed Tensor.
    workspace: Tensor = Tensor()
    mask: Tensor = Tensor()
    mask_argsort: Tensor = Tensor()
    indices: Tensor = Tensor()
    mask_output: Tensor = Tensor()
    # Integer stream value; presumably a CUDA stream handle. TODO confirm.
    stream: int
    def __init__(self, ndim: int, op_type: int) -> None:
        """Create a parameter object from ``ndim`` and ``op_type``.

        Args:
            ndim: presumably the number of spatial dimensions; TODO confirm.
            op_type: integer operation-type code; encoding not shown here.
        """
        ...
class ConvMainUnitTest:
    """Stub declaration of the static convolution entry points.

    All three methods are static and their bodies are elided (``...``).
    """
    @staticmethod
    def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]:
        """Return a list of ints computed from the convolution sizes.

        NOTE(review): by name the result is the GEMM ``[m, n, k]`` triple for
        the described convolution; confirm against the implementation.

        Args:
            op_type: integer operation-type code.
            N: integer size argument.
            C: integer size argument.
            K: integer size argument.
            kernel_volume: integer kernel volume.
            in_prod: integer; semantics not shown here.
            out_prod: integer; semantics not shown here.
            mask_sparse: boolean flag.
        """
        ...
    @staticmethod
    def implicit_gemm2(params: ConvParams) -> None:
        """Run the operation described by ``params``; returns ``None``.

        NOTE(review): which fields of ``params`` are read or written is not
        visible in this stub.

        Args:
            params: a ``ConvParams`` instance.
        """
        ...
    @staticmethod
    def get_all_conv_algo_desp() -> List[ConvAlgoDesp]: ...
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
class ScatterAll:
    """Stub declaration of the scatter operations.

    Bodies are elided (``...``); only signatures are declared here.
    """
    def __init__(self) -> None: ...
    @staticmethod
    def get_all_scatter_params() -> List[Tuple[int, int, int, int]]: ...
    def supported_scatter(self, tile_m: int, tile_k_bytes: int, bytes_per_access: int, num_threads: int, channel_size: int, dtype: int) -> bool:
        """Return a bool for the given tiling parameters.

        NOTE(review): presumably reports whether a scatter kernel exists for
        this configuration; the criteria are not visible here.

        Args:
            tile_m: integer tile parameter.
            tile_k_bytes: integer tile parameter.
            bytes_per_access: integer access-width parameter.
            num_threads: integer thread count.
            channel_size: integer channel size.
            dtype: integer dtype code; encoding not shown here.
        """
        ...
    @staticmethod
    def stream_synchronize(stream: int = 0) -> None:
        """Take an integer ``stream`` (default 0) and return ``None``.

        NOTE(review): by name this waits for the stream to finish; confirm.

        Args:
            stream: integer stream value, 0 by default.
        """
        ...
    def scatter(self, output: Tensor, input: Tensor, indices: Tensor, tile_m: int, tile_k_bytes: int, bytes_per_access: int, num_threads: int, stream: int = 0) -> None:
        """Scatter entry point with explicit tiling parameters; returns ``None``.

        NOTE(review): tensor shapes/dtypes and which tensor is written are
        not visible in this stub.

        Args:
            output: tensor argument.
            input: tensor argument.
            indices: tensor argument.
            tile_m: integer tile parameter.
            tile_k_bytes: integer tile parameter.
            bytes_per_access: integer access-width parameter.
            num_threads: integer thread count.
            stream: integer stream value, 0 by default.
        """
        ...
    def scatter2(self, output: Tensor, input: Tensor, indices: Tensor, size: int, stream: int = 0) -> None:
        """Scatter entry point taking a single ``size``; returns ``None``.

        NOTE(review): the meaning of ``size`` is not visible in this stub.

        Args:
            output: tensor argument.
            input: tensor argument.
            indices: tensor argument.
            size: integer size argument.
            stream: integer stream value, 0 by default.
        """
        ...
class GatherAll:
    """Stub declaration of the gather operations.

    Bodies are elided (``...``); only signatures are declared here.
    """
    def __init__(self) -> None: ...
    @staticmethod
    def get_all_gather_params() -> List[Tuple[int, int, int, int]]: ...
    @staticmethod
    def supported(bytes_per_access: int, channel_size: int, dtype: int) -> bool:
        """Return a bool for the given access parameters.

        NOTE(review): presumably reports whether a gather kernel exists for
        this configuration; the criteria are not visible here.

        Args:
            bytes_per_access: integer access-width parameter.
            channel_size: integer channel size.
            dtype: integer dtype code; encoding not shown here.
        """
        ...
    @staticmethod
    def stream_synchronize(stream: int = 0) -> None:
        """Take an integer ``stream`` (default 0) and return ``None``.

        NOTE(review): by name this waits for the stream to finish; confirm.

        Args:
            stream: integer stream value, 0 by default.
        """
        ...
    def gather(self, output: Tensor, input: Tensor, indices: Tensor, tile_m: int, tile_k_bytes: int, bytes_per_access: int, num_threads: int, stream: int = 0) -> None:
        """Gather entry point with explicit tiling parameters; returns ``None``.

        NOTE(review): tensor shapes/dtypes and which tensor is written are
        not visible in this stub.

        Args:
            output: tensor argument.
            input: tensor argument.
            indices: tensor argument.
            tile_m: integer tile parameter.
            tile_k_bytes: integer tile parameter.
            bytes_per_access: integer access-width parameter.
            num_threads: integer thread count.
            stream: integer stream value, 0 by default.
        """
        ...
    def gather2(self, output: Tensor, input: Tensor, indices: Tensor, size: int, stream: int = 0) -> None:
        """Gather entry point taking a single ``size``; returns ``None``.

        NOTE(review): the meaning of ``size`` is not visible in this stub.

        Args:
            output: tensor argument.
            input: tensor argument.
            indices: tensor argument.
            size: integer size argument.
            stream: integer stream value, 0 by default.
        """
        ...
...@@ -18,6 +18,7 @@ class GemmAlgoDesp: ...@@ -18,6 +18,7 @@ class GemmAlgoDesp:
element_per_access_a: int element_per_access_a: int
element_per_access_b: int element_per_access_b: int
element_per_access_c: int element_per_access_c: int
access_per_vector: int
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __repr__(self) -> str: ... def __repr__(self) -> str: ...
@property @property
......
...@@ -12,33 +12,77 @@ ...@@ -12,33 +12,77 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from cumm.common import TensorViewKernel, ThrustLib from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib
from cumm.conv.bases import ConvOpType, NHWC from cumm.conv.bases import ConvOpType, NHWC
from cumm.conv.params import ConvProblem from cumm.conv.params import ConvProblem
from cumm import dtypes from cumm import dtypes
from cumm.constants import CUMM_CPU_ONLY_BUILD
import pccm import pccm
from ccimport import compat from ccimport import compat
from .pointops import Point2Voxel, Point2VoxelCPU from .pointops import Point2Voxel, Point2VoxelCPU
from .indices import SparseConvIndicesKernel, CudaCommonKernel from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndicesCPU
from .maxpool import IndiceMaxPool from .maxpool import IndiceMaxPool, IndiceMaxPoolCPU
from .gather import GatherCPU
class CustomThrustLib(pccm.Class):
    """pccm class that depends on ``ThrustLib`` and adds extra nvcc flags."""

    def __init__(self):
        super().__init__()
        self.add_dependency(ThrustLib)
        # Flags taken from the discussion at
        # https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
        host_compiler_flags = ("-Xcompiler", "-fno-gnu-unique")
        self.build_meta.add_cflags("nvcc", *host_compiler_flags)
class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin):
    """pccm description of a byte allocator backed by a user callback.

    The generated class exposes an ``alloc_func`` member (a
    ``std::function`` taking a size and returning an integer address) and a
    ``value_type`` typedef of ``char``. ``allocate`` forwards to the
    callback; ``deallocate`` is generated with an empty body.
    """

    def __init__(self):
        super().__init__()
        self.add_dependency(TensorView)
        self.add_include("functional", "memory")
        callback_type = "std::function<std::uintptr_t(std::size_t)>"
        self.add_pybind_member("alloc_func", callback_type, pyanno="Callable[[int], int]")
        self.add_typedef("value_type", "char")

    @pccm.member_function
    def allocate(self):
        """Describe ``char* allocate(std::ptrdiff_t num_bytes)``.

        The emitted body calls ``alloc_func`` when it is set and reinterprets
        the returned integer as ``char*``; otherwise it raises a runtime
        error through ``TV_THROW_RT_ERR``.
        """
        fn = pccm.FunctionCode()
        fn.arg("num_bytes", "std::ptrdiff_t")
        fn.ret("char*")
        fn.raw("""
if (alloc_func){
char* result = reinterpret_cast<char*>(alloc_func(num_bytes));
return result;
}
else{
TV_THROW_RT_ERR("set alloc function first.");
}
""")
        return fn

    @pccm.member_function
    def deallocate(self):
        """Describe ``deallocate(char* ptr, size_t num_bytes)`` with no body."""
        fn = pccm.FunctionCode()
        fn.arg("ptr", "char *")
        fn.arg("num_bytes", "size_t")
        return fn
class SpconvOps(pccm.Class): class SpconvOps(pccm.Class):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_dependency(ThrustCustomAllocatorV2)
self.ndims = [1, 2, 3, 4] self.ndims = [1, 2, 3, 4]
for ndim in self.ndims: for ndim in self.ndims:
p2v = Point2Voxel(dtypes.float32, ndim) p2v = Point2Voxel(dtypes.float32, ndim)
p2v_cpu = Point2VoxelCPU(dtypes.float32, ndim) p2v_cpu = Point2VoxelCPU(dtypes.float32, ndim)
self.add_param_class(f"ops{ndim}d", p2v, f"Point2Voxel{ndim}D")
self.add_param_class(f"ops_cpu{ndim}d", p2v_cpu, f"Point2Voxel{ndim}DCPU") self.add_param_class(f"ops_cpu{ndim}d", p2v_cpu, f"Point2Voxel{ndim}DCPU")
problem = ConvProblem(ndim, ConvOpType.kForward, NHWC, NHWC, NHWC) problem = ConvProblem(ndim, ConvOpType.kForward, NHWC, NHWC, NHWC)
indices = SparseConvIndicesKernel(problem, dtypes.int32) indices = SparseConvIndicesKernel(problem, dtypes.int32)
indices_cpu = SparseConvIndicesCPU(problem, dtypes.int32)
self.add_param_class(f"ops_cpu{ndim}d", indices_cpu, f"SpconvIndicesCPU{ndim}D")
# self.add_param_class("ops", indices, "SpconvIndices") # self.add_param_class("ops", indices, "SpconvIndices")
cuda_funcs = [self.generate_subm_conv_inds, if not CUMM_CPU_ONLY_BUILD:
self.generate_conv_inds_stage1, self.generate_conv_inds_stage1_5, self.generate_conv_inds_stage2, self.sort_1d_by_key] self.add_param_class(f"ops{ndim}d", p2v, f"Point2Voxel{ndim}D")
self.add_impl_only_param_class(cuda_funcs, f"ops{ndim}d", indices, f"SpconvIndices{ndim}D") cuda_funcs = [self.generate_subm_conv_inds,
self.generate_conv_inds_stage1, self.generate_conv_inds_stage1_5, self.generate_conv_inds_stage2, self.sort_1d_by_key,
self.generate_conv_inds_mask_stage1, self.generate_conv_inds_mask_stage2]
self.add_impl_only_param_class(cuda_funcs, f"ops{ndim}d", indices, f"SpconvIndices{ndim}D")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.static_function @pccm.cuda.static_function
...@@ -52,6 +96,14 @@ class SpconvOps(pccm.Class): ...@@ -52,6 +96,14 @@ class SpconvOps(pccm.Class):
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
TV_THROW_RT_ERR("CPU ONLY build, don't support cuda algorithm.");
""")
return code
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim && TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...@@ -90,6 +142,15 @@ class SpconvOps(pccm.Class): ...@@ -90,6 +142,15 @@ class SpconvOps(pccm.Class):
code.arg("ndim", "int") code.arg("ndim", "int")
code.arg("uniq_size", "int64_t") code.arg("uniq_size", "int64_t")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
TV_THROW_RT_ERR("CPU ONLY build, don't support cuda algorithm.");
""")
return code.ret("int")
for ndim in self.ndims: for ndim in self.ndims:
code.raw(f""" code.raw(f"""
if (ndim == {ndim}){{ if (ndim == {ndim}){{
...@@ -111,6 +172,15 @@ class SpconvOps(pccm.Class): ...@@ -111,6 +172,15 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>") code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
TV_THROW_RT_ERR("CPU ONLY build, don't support cuda algorithm.");
""")
return code.ret("int")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim && TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...@@ -141,10 +211,113 @@ class SpconvOps(pccm.Class): ...@@ -141,10 +211,113 @@ class SpconvOps(pccm.Class):
return code.ret("int") return code.ret("int")
@pccm.pybind.mark
@pccm.cuda.static_function
def generate_conv_inds_mask_stage1(self):
    """Describe the static function ``generate_conv_inds_mask_stage1``.

    The emitted C++ derives ``ndim`` as ``indices.dim(1) - 1``, asserts that
    every per-dimension vector argument has exactly ``ndim`` entries, and,
    for each value in ``self.ndims``, copies the vectors into fixed-size
    ``tv::array`` objects and forwards to
    ``SpconvIndices{ndim}D::generate_conv_inds_mask_stage1``. An ``ndim``
    outside ``self.ndims`` raises through ``TV_THROW_RT_ERR``.

    Returns:
        The ``pccm.FunctionCode``; in a CPU-only build, the result of
        ``code.make_invalid()`` on an otherwise empty function.
    """
    code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        # CUDA-only entry point. Because this returns unconditionally, the
        # second CPU-only check the original carried further down could never
        # run and has been removed.
        return code.make_invalid()
    code.arg("indices", "tv::Tensor")
    code.arg("indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc", "tv::Tensor")
    code.arg("batch_size", "int")
    code.arg("output_dims, input_dims", f"std::vector<int>")
    code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
    code.arg("transposed", f"bool", "false")
    code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
    code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
""")
    # One runtime branch per supported dimensionality.
    for ndim in self.ndims:
        code.raw(f"""
if (ndim == {ndim}){{
tv::array<int, {ndim}> output_dims_, input_dims_;
tv::array<int, {ndim}> ksize_, stride_, padding_, dilation_;
for (int i = 0; i < {ndim}; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndices{ndim}D::generate_conv_inds_mask_stage1(indices,
indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc,
batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
""")
    code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
    # No return type is declared here, unlike generate_conv_inds_mask_stage2.
    return code
@pccm.pybind.mark
@pccm.cuda.static_function
def generate_conv_inds_mask_stage2(self):
    """Describe the static function ``generate_conv_inds_mask_stage2``.

    The emitted C++ derives ``ndim`` as ``indices.dim(1) - 1``, asserts that
    every per-dimension vector argument has exactly ``ndim`` entries, and,
    for each value in ``self.ndims``, copies the vectors into fixed-size
    ``tv::array`` objects and forwards to
    ``SpconvIndices{ndim}D::generate_conv_inds_stage2_mask``. An ``ndim``
    outside ``self.ndims`` raises through ``TV_THROW_RT_ERR``.

    Returns:
        The ``pccm.FunctionCode`` with return type ``int``; in a CPU-only
        build, the result of ``code.make_invalid()`` on an otherwise empty
        function.
    """
    code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        # CUDA-only entry point. Because this returns unconditionally, the
        # second CPU-only check the original carried further down could never
        # run and has been removed.
        return code.make_invalid()
    code.arg("indices, hashdata", "tv::Tensor")
    code.arg("indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds", "tv::Tensor")
    code.arg("mask_fwd, mask_bwd", "tv::Tensor")
    code.arg("num_out_act", "int")
    code.arg("batch_size", "int")
    code.arg("output_dims, input_dims", f"std::vector<int>")
    code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
    code.arg("transposed", f"bool", "false")
    code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
    code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
""")
    # One runtime branch per supported dimensionality.
    for ndim in self.ndims:
        code.raw(f"""
if (ndim == {ndim}){{
tv::array<int, {ndim}> output_dims_, input_dims_;
tv::array<int, {ndim}> ksize_, stride_, padding_, dilation_;
for (int i = 0; i < {ndim}; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndices{ndim}D::generate_conv_inds_stage2_mask(indices, hashdata,
indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
""")
    code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
    return code.ret("int")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_subm_conv_inds(self): def generate_subm_conv_inds(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("indices, hashdata", "tv::Tensor") code.arg("indices, hashdata", "tv::Tensor")
code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor") code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor")
code.arg("batch_size", "int") code.arg("batch_size", "int")
...@@ -153,6 +326,12 @@ class SpconvOps(pccm.Class): ...@@ -153,6 +326,12 @@ class SpconvOps(pccm.Class):
code.arg("indice_pair_mask", "tv::Tensor", "tv::Tensor()", "cumm.tensorview.Tensor = Tensor()") code.arg("indice_pair_mask", "tv::Tensor", "tv::Tensor()", "cumm.tensorview.Tensor = Tensor()")
code.arg("backward", "bool", "false") code.arg("backward", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
TV_THROW_RT_ERR("CPU ONLY build, don't support cuda algorithm.");
""")
return code.ret("int")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim && TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
...@@ -178,15 +357,97 @@ class SpconvOps(pccm.Class): ...@@ -178,15 +357,97 @@ class SpconvOps(pccm.Class):
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
return code.ret("int") return code.ret("int")
@pccm.pybind.mark
@pccm.static_function
def generate_conv_inds_cpu(self):
    """Describe the static function ``generate_conv_inds_cpu``.

    The emitted C++ derives ``ndim`` as ``indices.dim(1) - 1``, asserts that
    every per-dimension vector argument has exactly ``ndim`` entries, then
    dispatches on ``ndim`` to ``SpconvIndicesCPU{ndim}D::generate_conv_inds``
    for each value in ``self.ndims``; any other ``ndim`` raises through
    ``TV_THROW_RT_ERR``. The declared return type is ``int``.
    """
    fn = pccm.FunctionCode()
    int_vector = "std::vector<int>"
    fn.arg("indices", "tv::Tensor")
    fn.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor")
    fn.arg("batch_size", "int")
    fn.arg("output_dims, input_dims", int_vector)
    fn.arg("ksize, stride, padding, dilation", int_vector)
    fn.arg("transposed", "bool", "false")
    fn.raw("""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
""")
    # Emit one branch per supported dimensionality.
    for ndim in self.ndims:
        fn.raw(f"""
if (ndim == {ndim}){{
tv::array<int, {ndim}> output_dims_, input_dims_;
tv::array<int, {ndim}> ksize_, stride_, padding_, dilation_;
for (int i = 0; i < {ndim}; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndicesCPU{ndim}D::generate_conv_inds(indices,
indice_pairs, out_inds, indice_num_per_loc,
batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed);
}}
""")
    fn.raw("""TV_THROW_RT_ERR("unknown ndim", ndim);""")
    return fn.ret("int")
@pccm.pybind.mark
@pccm.static_function
def generate_subm_conv_inds_cpu(self):
    """Describe the static function ``generate_subm_conv_inds_cpu``.

    The emitted C++ derives ``ndim`` as ``indices.dim(1) - 1``, asserts that
    ``input_dims``, ``ksize`` and ``dilation`` each have exactly ``ndim``
    entries, then dispatches on ``ndim`` to
    ``SpconvIndicesCPU{ndim}D::generate_subm_conv_inds`` for each value in
    ``self.ndims``; any other ``ndim`` raises through ``TV_THROW_RT_ERR``.
    The declared return type is ``int``.
    """
    fn = pccm.FunctionCode()
    int_vector = "std::vector<int>"
    fn.arg("indices", "tv::Tensor")
    fn.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor")
    fn.arg("batch_size", "int")
    fn.arg("input_dims", int_vector)
    fn.arg("ksize, dilation", int_vector)
    fn.raw("""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
ksize.size() == ndim && dilation.size() == ndim, "your params size not equal to ndim", ndim);
""")
    # Emit one branch per supported dimensionality.
    for ndim in self.ndims:
        fn.raw(f"""
if (ndim == {ndim}){{
tv::array<int, {ndim}> input_dims_;
tv::array<int, {ndim}> ksize_, dilation_;
for (int i = 0; i < {ndim}; ++i){{
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
dilation_[i] = dilation[i];
}}
return SpconvIndicesCPU{ndim}D::generate_subm_conv_inds(indices,
indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_,
ksize_, dilation_);
}}
""")
    fn.raw("""TV_THROW_RT_ERR("unknown ndim", ndim);""")
    return fn.ret("int")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.static_function @pccm.cuda.static_function
def maxpool_forward(self): def maxpool_forward(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("out", "tv::Tensor") code.arg("out", "tv::Tensor")
code.arg("inp", "tv::Tensor") code.arg("inp", "tv::Tensor")
code.arg("out_inds", "tv::Tensor") code.arg("out_inds", "tv::Tensor")
code.arg("in_inds", "tv::Tensor") code.arg("in_inds", "tv::Tensor")
code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.arg("stream", "std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
TV_THROW_RT_ERR("CPU ONLY build, don't support cuda algorithm.");
""")
return code
code.add_dependency(IndiceMaxPool) code.add_dependency(IndiceMaxPool)
code.raw(f""" code.raw(f"""
return IndiceMaxPool::forward(out, inp, out_inds, in_inds, stream); return IndiceMaxPool::forward(out, inp, out_inds, in_inds, stream);
...@@ -197,6 +458,8 @@ class SpconvOps(pccm.Class): ...@@ -197,6 +458,8 @@ class SpconvOps(pccm.Class):
@pccm.cuda.static_function @pccm.cuda.static_function
def maxpool_backward(self): def maxpool_backward(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("out", "tv::Tensor") code.arg("out", "tv::Tensor")
code.arg("inp", "tv::Tensor") code.arg("inp", "tv::Tensor")
code.arg("dout", "tv::Tensor") code.arg("dout", "tv::Tensor")
...@@ -204,30 +467,529 @@ class SpconvOps(pccm.Class): ...@@ -204,30 +467,529 @@ class SpconvOps(pccm.Class):
code.arg("out_inds", "tv::Tensor") code.arg("out_inds", "tv::Tensor")
code.arg("in_inds", "tv::Tensor") code.arg("in_inds", "tv::Tensor")
code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.arg("stream", "std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
TV_THROW_RT_ERR("CPU ONLY build, don't support cuda algorithm.");
""")
return code
code.add_dependency(IndiceMaxPool) code.add_dependency(IndiceMaxPool)
code.raw(f""" code.raw(f"""
return IndiceMaxPool::backward(out, inp, dout, dinp, out_inds, in_inds, stream); return IndiceMaxPool::backward(out, inp, dout, dinp, out_inds, in_inds, stream);
""") """)
return code return code
@pccm.pybind.mark
@pccm.cuda.static_function
def maxpool_implicit_gemm_forward(self):
    """Describe the static function ``maxpool_implicit_gemm_forward``.

    The emitted C++ forwards ``out``, ``inp``, ``inds`` and ``stream`` to
    ``IndiceMaxPool::forward_implicit_gemm``.

    Returns:
        The ``pccm.FunctionCode``; in a CPU-only build, the result of
        ``code.make_invalid()`` on an otherwise empty function.
    """
    code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        # CUDA-only entry point. Because this returns unconditionally, the
        # second CPU-only check the original carried further down could never
        # run and has been removed.
        return code.make_invalid()
    code.arg("out", "tv::Tensor")
    code.arg("inp", "tv::Tensor")
    code.arg("inds", "tv::Tensor")
    code.arg("stream", "std::uintptr_t", "0", pyanno="int")
    code.add_dependency(IndiceMaxPool)
    code.raw(f"""
return IndiceMaxPool::forward_implicit_gemm(out, inp, inds, stream);
""")
    return code
@pccm.pybind.mark
@pccm.cuda.static_function
def maxpool_implicit_gemm_backward(self):
    """Describe the static function ``maxpool_implicit_gemm_backward``.

    The emitted C++ forwards ``out``, ``inp``, ``dout``, ``dinp``, ``inds``
    and ``stream`` to ``IndiceMaxPool::backward_implicit_gemm``.

    Returns:
        The ``pccm.FunctionCode``; in a CPU-only build, the result of
        ``code.make_invalid()`` on an otherwise empty function.
    """
    code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        # CUDA-only entry point. Because this returns unconditionally, the
        # second CPU-only check the original carried further down could never
        # run and has been removed.
        return code.make_invalid()
    code.arg("out", "tv::Tensor")
    code.arg("inp", "tv::Tensor")
    code.arg("dout", "tv::Tensor")
    code.arg("dinp", "tv::Tensor")
    code.arg("inds", "tv::Tensor")
    code.arg("stream", "std::uintptr_t", "0", pyanno="int")
    code.add_dependency(IndiceMaxPool)
    code.raw(f"""
return IndiceMaxPool::backward_implicit_gemm(out, inp, dout, dinp, inds, stream);
""")
    return code
@pccm.pybind.mark
@pccm.static_function
def maxpool_forward_cpu(self):
    """Describe the static function ``maxpool_forward_cpu``.

    Declares four ``tv::Tensor`` arguments and emits a body that forwards
    them to ``IndiceMaxPoolCPU::forward``.
    """
    fn = pccm.FunctionCode()
    for tensor_name in ("out", "inp", "out_inds", "in_inds"):
        fn.arg(tensor_name, "tv::Tensor")
    fn.add_dependency(IndiceMaxPoolCPU)
    fn.raw("""
return IndiceMaxPoolCPU::forward(out, inp, out_inds, in_inds);
""")
    return fn
@pccm.pybind.mark
@pccm.static_function
def maxpool_backward_cpu(self):
    """Describe the static function ``maxpool_backward_cpu``.

    Declares six ``tv::Tensor`` arguments and emits a body that forwards
    them to ``IndiceMaxPoolCPU::backward``.
    """
    fn = pccm.FunctionCode()
    for tensor_name in ("out", "inp", "dout", "dinp", "out_inds", "in_inds"):
        fn.arg(tensor_name, "tv::Tensor")
    fn.add_dependency(IndiceMaxPoolCPU)
    fn.raw("""
return IndiceMaxPoolCPU::backward(out, inp, dout, dinp, out_inds, in_inds);
""")
    return fn
@pccm.pybind.mark
@pccm.static_function
def gather_cpu(self):
    """Describe the static function ``gather_cpu``.

    Declares three ``tv::Tensor`` arguments and emits a body that forwards
    them to ``GatherCPU::gather``.
    """
    fn = pccm.FunctionCode()
    for tensor_name in ("out", "inp", "inds"):
        fn.arg(tensor_name, "tv::Tensor")
    fn.add_dependency(GatherCPU)
    fn.raw("""
return GatherCPU::gather(out, inp, inds);
""")
    return fn
@pccm.pybind.mark
@pccm.static_function
def scatter_add_cpu(self):
    """Describe the static function ``scatter_add_cpu``.

    Declares three ``tv::Tensor`` arguments and emits a body that forwards
    them to ``GatherCPU::scatter_add``.
    """
    fn = pccm.FunctionCode()
    for tensor_name in ("out", "inp", "inds"):
        fn.arg(tensor_name, "tv::Tensor")
    fn.add_dependency(GatherCPU)
    fn.raw("""
return GatherCPU::scatter_add(out, inp, inds);
""")
    return fn
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.static_function @pccm.cuda.static_function
def sort_1d_by_key(self): def sort_1d_by_key(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("data", "tv::Tensor")
code.arg("indices", "tv::Tensor", "tv::Tensor()", pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.code_after_include = f"""
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
TV_THROW_RT_ERR("CPU ONLY build, don't support cuda algorithm.");
""")
return code.ret("tv::Tensor")
code.add_dependency(ThrustLib, TensorViewKernel) code.add_dependency(ThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel()) code.add_param_class("cudakers", CudaCommonKernel())
code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
thrust::stable_sort_by_key(thrust_ctx, ptr_tr, ptr_tr + data.dim(0), ptr_k, SmallOrEqualTo<uint32_t>());
}});
tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.cuda.static_function
def sort_1d_by_key_allocator(self):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("data", "tv::Tensor") code.arg("data", "tv::Tensor")
code.arg("alloc_func", "std::function<std::uintptr_t(std::size_t)>")
code.arg("indices", "tv::Tensor", "tv::Tensor()", pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.code_after_include = f"""
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
TV_THROW_RT_ERR("CPU ONLY build, don't support cuda algorithm.");
""")
return code.ret("tv::Tensor")
code.add_dependency(ThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel())
code.raw(f""" code.raw(f"""
tv::Tensor indices({{data.dim(0)}}, tv::int32, 0); ThrustCustomAllocatorV2 allocator{{alloc_func}};
tv::cuda::Launch launcher(data.dim(0)); cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0)); launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
// auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{ tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I); using T = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>()); thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>()); thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(0); auto thrust_ctx = thrust::cuda::par.on(stream_cu);
thrust::sort_by_key(thrust_ctx, ptr_tr, ptr_tr + data.dim(0), ptr_k); auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k);
}}); }});
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices; return indices;
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.cuda.static_function
def sort_1d_by_key_split(self):
    """Describe the static function ``sort_1d_by_key_split``.

    The emitted C++ fills ``indices`` with ``cudakers::arange_kernel``
    (allocating it first when empty), then runs ``thrust::sort_by_key`` on
    ``data`` with ``indices`` as values, comparing elements after AND-ing
    them with ``mask[0]``. When ``mask_output`` is true, ``data`` is
    additionally AND-ed in place with that mask. ``indices`` is returned.

    NOTE(review): ``mask.data_ptr<T>()[0]`` is dereferenced in host code, so
    ``mask`` is presumably expected to be host-accessible; confirm with
    callers.

    Returns:
        The ``pccm.FunctionCode`` with return type ``tv::Tensor``; in a
        CPU-only build, the result of ``code.make_invalid()`` on an
        otherwise empty function.
    """
    code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        # CUDA-only entry point. Because this returns unconditionally, the
        # second CPU-only check the original carried further down could never
        # run and has been removed.
        return code.make_invalid()
    code.arg("data", "tv::Tensor")
    code.arg("mask", "tv::Tensor")
    code.arg("indices", "tv::Tensor", "tv::Tensor()", pyanno="cumm.tensorview.Tensor = Tensor()")
    code.arg("stream", "std::uintptr_t", "0", pyanno="int")
    code.arg("mask_output", "bool", "false")
    # Helper comparator and kernel placed after the includes of the
    # generated source.
    code.code_after_include = f"""
template <typename T> struct MaskedElementComp {{
T mask_;
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return (x & mask_) < (y & mask_);
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
    code.add_dependency(CustomThrustLib, TensorViewKernel)
    code.add_param_class("cudakers", CudaCommonKernel())
    code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
// auto timer = tv::CudaContextTimer<>();
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto masks_ptr = mask.data_ptr<T>();
MaskedElementComp<T> op_comp{{masks_ptr[0]}};
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
thrust::sort_by_key(thrust_ctx, ptr_tr, ptr_tr + data.dim(0), ptr_k, op_comp);
if (mask_output){{
launcher(mask_input<T>, data.data_ptr<T>(), masks_ptr[0], data.dim(0));
}}
}});
// tv::ssprint("SORT BY KEY MASKED TIME", timer.report() / 1000.0);
return indices;
""")
    return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.cuda.static_function
def sort_1d_by_key_split_allocator(self):
    """Describe the static function ``sort_1d_by_key_split_allocator``.

    Same generated logic as the masked key sort, except that the thrust
    execution policy is built with a ``ThrustCustomAllocatorV2`` wrapping the
    caller-supplied ``alloc_func``: ``indices`` is filled with
    ``cudakers::arange_kernel`` (allocated first when empty), ``data`` is
    sorted by key with elements compared after AND-ing with ``mask[0]``, and
    when ``mask_output`` is true ``data`` is AND-ed in place with that mask.
    ``indices`` is returned.

    NOTE(review): ``mask.data_ptr<T>()[0]`` is dereferenced in host code, so
    ``mask`` is presumably expected to be host-accessible; confirm with
    callers.

    Returns:
        The ``pccm.FunctionCode`` with return type ``tv::Tensor``; in a
        CPU-only build, the result of ``code.make_invalid()`` on an
        otherwise empty function.
    """
    code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        # CUDA-only entry point. Because this returns unconditionally, the
        # second CPU-only check the original carried further down could never
        # run and has been removed.
        return code.make_invalid()
    code.arg("data", "tv::Tensor")
    code.arg("alloc_func", "std::function<std::uintptr_t(std::size_t)>")
    code.arg("mask", "tv::Tensor")
    code.arg("indices", "tv::Tensor", "tv::Tensor()", pyanno="cumm.tensorview.Tensor = Tensor()")
    code.arg("stream", "std::uintptr_t", "0", pyanno="int")
    code.arg("mask_output", "bool", "false")
    # Helper comparator and kernel placed after the includes of the
    # generated source.
    code.code_after_include = f"""
template <typename T> struct MaskedElementComp {{
T mask_;
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return (x & mask_) < (y & mask_);
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
    code.add_dependency(CustomThrustLib, TensorViewKernel)
    code.add_param_class("cudakers", CudaCommonKernel())
    code.raw(f"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
// auto timer = tv::CudaContextTimer<>();
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto masks_ptr = mask.data_ptr<T>();
MaskedElementComp<T> op_comp{{masks_ptr[0]}};
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
// auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k, op_comp);
if (mask_output){{
launcher(mask_input<T>, data.data_ptr<T>(), masks_ptr[0], data.dim(0));
}}
}});
// tv::ssprint("SORT_BY_KEY_MASKED", timer.report() / 1000.0);
return indices;
""")
    return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.cuda.static_function
def count_bits(self):
    """Describe the static function ``count_bits``.

    The emitted C++ takes a tensor ``a`` whose dtype is dispatched over
    ``uint32_t`` and ``uint64_t`` and returns an ``int32`` tensor with the
    same shape and device holding the number of set bits of each element.
    When ``a.device() == -1`` the counts are computed in a host loop with
    ``numberOfSetBits``; otherwise a kernel (``count_bits_kernel`` or
    ``count_bits_kernel_64``) is launched.

    Returns:
        The ``pccm.FunctionCode`` with return type ``tv::Tensor``; in a
        CPU-only build, the result of ``code.make_invalid()`` on an
        otherwise empty function.
    """
    code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        # In a CPU-only build this whole function, host path included,
        # is returned via make_invalid().
        return code.make_invalid()
    code.add_dependency(TensorViewKernel)
    code.arg("a", "tv::Tensor")
    # Kernels and host-side popcount helpers placed after the includes of
    # the generated source.
    code.code_after_include = f"""
__global__ void count_bits_kernel_64(const uint64_t* data, int32_t* out, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
out[i] = __popcll(reinterpret_cast<const unsigned long long*>(data)[i]);
}}
}}
__global__ void count_bits_kernel(const uint32_t* data, int32_t* out, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
out[i] = __popc(data[i]);
}}
}}
int numberOfSetBits(uint32_t i)
{{
// https://stackoverflow.com/questions/109023/how-to-count-the-number-of-set-bits-in-a-32-bit-integer
// Java: use int, and use >>> instead of >>. Or use Integer.bitCount()
// C or C++: use uint32_t
i = i - ((i >> 1) & 0x55555555); // add pairs of bits
i = (i & 0x33333333) + ((i >> 2) & 0x33333333); // quads
i = (i + (i >> 4)) & 0x0F0F0F0F; // groups of 8
return (i * 0x01010101) >> 24; // horizontal sum of bytes
}}
int numberOfSetBits(uint64_t i)
{{
return numberOfSetBits(uint32_t(i)) + numberOfSetBits(uint32_t(i >> 32));
}}
"""
    # Function body: allocate the result, then pick the host loop or the
    # kernel launch depending on a.device().
    code.raw(f"""
tv::Tensor res(a.shape(), tv::int32, a.device());
tv::dispatch<uint32_t, uint64_t>(a.dtype(), [&](auto I){{
auto res_ptr = res.data_ptr<int>();
using T = TV_DECLTYPE(I);
auto a_ptr = a.data_ptr<const T>();
if (a.device() == -1){{
for (int i = 0; i < a.size(); ++i){{
res_ptr[i] = numberOfSetBits(a_ptr[i]);
}}
}}else{{
tv::cuda::Launch launcher(a.size());
tv::if_constexpr<std::is_same<T, uint64_t>::value>([=](auto _)mutable{{
launcher(_(count_bits_kernel_64), a_ptr, res_ptr, int(a.size()));
}}, [=](auto _)mutable{{
launcher(_(count_bits_kernel), a_ptr, res_ptr, int(a.size()));
}});
}}
}});
return res;
""")
    return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.static_function
def calc_point2voxel_meta_data(self):
    """Emit the C++ source for ``calc_point2voxel_meta_data``.

    The generated function takes the voxel size and coordinate range as
    ``std::vector<float>``, derives ``ndim`` from ``vsize_xyz.size()`` and,
    for every value in ``self.ndims``, forwards fixed-size copies to
    ``Point2Voxel{ndim}DCPU::calc_meta_data``.  The four arrays of the
    result are copied back into vectors and returned as a tuple of
    (vsize, grid_size, grid_stride, coors_range).  An ``ndim`` that is not
    in ``self.ndims`` raises a runtime error.
    """
    code = pccm.FunctionCode()
    for vec_arg in ("vsize_xyz", "coors_range_xyz"):
        code.arg(vec_arg, "std::vector<float>")

    # Runtime size validation shared by every ndim branch.
    code.raw("""
int ndim = vsize_xyz.size();
TV_ASSERT_RT_ERR(vsize_xyz.size() == ndim &&
coors_range_xyz.size() == ndim * 2, "your params size not equal to ndim", ndim);
""")

    # One `if (ndim == N)` branch per supported dimensionality.
    for dim in self.ndims:
        range_len = dim * 2
        code.raw(f"""
if (ndim == {dim}){{
std::array<float, {dim}> vsize_xyz_;
std::array<float, {range_len}> coors_range_xyz_;
for (int i = 0; i < {dim}; ++i){{
vsize_xyz_[i] = vsize_xyz[i];
coors_range_xyz_[i] = coors_range_xyz[i];
coors_range_xyz_[i + {dim}] = coors_range_xyz[i + {dim}];
}}
auto res = Point2Voxel{dim}DCPU::calc_meta_data(vsize_xyz_, coors_range_xyz_);
std::vector<float> vsize({dim}), coors_range({range_len});
std::vector<int> grid_size({dim}), grid_stride({dim});
for (int i = 0; i < {dim}; ++i){{
vsize[i] = std::get<0>(res)[i];
grid_size[i] = std::get<1>(res)[i];
grid_stride[i] = std::get<2>(res)[i];
coors_range[i] = std::get<3>(res)[i];
coors_range[i + {dim}] = std::get<3>(res)[i + {dim}];
}}
return std::make_tuple(vsize, grid_size, grid_stride, coors_range);
}}
""")

    # Reached only when no branch above returned.
    code.raw('TV_THROW_RT_ERR("unknown ndim", ndim);')
    result_type = (
        "std::tuple<std::vector<float>, std::vector<int>, "
        "std::vector<int>, std::vector<float>>"
    )
    return code.ret(result_type)
@pccm.pybind.mark
@pccm.static_function
def point2voxel_cpu(self):
    """Emit the C++ source for the ``point2voxel_cpu`` dispatcher.

    The generated function checks that ``vsize``, ``grid_size``,
    ``grid_stride`` and ``coors_range`` agree with ``ndim = vsize.size()``,
    copies them into fixed-size arrays and, for each value in
    ``self.ndims``, forwards to ``Point2Voxel{ndim}DCPU``:
    ``point_to_voxel_empty_mean_static`` when ``empty_mean`` is set,
    ``point_to_voxel_static`` otherwise.  An unsupported ``ndim`` raises a
    runtime error.
    """
    code = pccm.FunctionCode()
    arg_specs = (
        ("points", "tv::Tensor"),
        ("voxels, indices, num_per_voxel, densehashdata", "tv::Tensor"),
        ("vsize", "std::vector<float>"),
        ("grid_size, grid_stride", "std::vector<int>"),
        ("coors_range", "std::vector<float>"),
        ("empty_mean", "bool", "false"),
        ("clear_voxels", "bool", "true"),
    )
    for spec in arg_specs:
        code.arg(*spec)

    # Runtime size validation shared by every ndim branch.
    code.raw("""
int ndim = vsize.size();
TV_ASSERT_RT_ERR(vsize.size() == ndim && grid_stride.size() == ndim &&
coors_range.size() == ndim * 2 && grid_size.size() == ndim,
"your params size not equal to ndim", ndim);
// voxels: []
""")

    # One `if (ndim == N)` branch per supported dimensionality.
    for dim in self.ndims:
        range_len = dim * 2
        code.raw(f"""
if (ndim == {dim}){{
std::array<float, {dim}> vsize_;
std::array<int, {dim}> grid_size_, grid_stride_;
std::array<float, {range_len}> coors_range_;
for (int i = 0; i < {dim}; ++i){{
vsize_[i] = vsize[i];
grid_size_[i] = grid_size[i];
grid_stride_[i] = grid_stride[i];
coors_range_[i] = coors_range[i];
coors_range_[i + {dim}] = coors_range[i + {dim}];
}}
if (empty_mean){{
return Point2Voxel{dim}DCPU::point_to_voxel_empty_mean_static(points, voxels, indices,
num_per_voxel, densehashdata,
vsize_, grid_size_, grid_stride_, coors_range_, clear_voxels);
}} else{{
return Point2Voxel{dim}DCPU::point_to_voxel_static(points, voxels, indices,
num_per_voxel, densehashdata,
vsize_, grid_size_, grid_stride_, coors_range_, clear_voxels);
}}
}}
""")

    # Reached only when no branch above returned.
    code.raw('TV_THROW_RT_ERR("unknown ndim", ndim);')
    return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
@pccm.pybind.mark
@pccm.static_function
def point2voxel_cuda(self):
    """Emit the C++ source for the ``point2voxel_cuda`` dispatcher.

    The generated function checks that ``vsize``, ``grid_size``,
    ``grid_stride`` and ``coors_range`` agree with ``ndim = vsize.size()``,
    copies them into fixed-size arrays and, for each value in
    ``self.ndims``, forwards everything (including ``empty_mean`` and
    ``stream_int``) to ``Point2Voxel{ndim}D::point_to_voxel_hash_static``.
    An unsupported ``ndim`` raises a runtime error.
    """
    code = pccm.FunctionCode()
    arg_specs = (
        ("points", "tv::Tensor"),
        ("voxels, indices, num_per_voxel, hashdata, point_indice_data", "tv::Tensor"),
        ("vsize", "std::vector<float>"),
        ("grid_size, grid_stride", "std::vector<int>"),
        ("coors_range", "std::vector<float>"),
        ("empty_mean", "bool", "false"),
        ("clear_voxels", "bool", "true"),
        ("stream_int", "std::uintptr_t", "0"),
    )
    for spec in arg_specs:
        code.arg(*spec)

    # Runtime size validation shared by every ndim branch.
    code.raw("""
int ndim = vsize.size();
TV_ASSERT_RT_ERR(vsize.size() == ndim && grid_stride.size() == ndim &&
coors_range.size() == ndim * 2 && grid_size.size() == ndim,
"your params size not equal to ndim", ndim);
// voxels: []
""")

    # One `if (ndim == N)` branch per supported dimensionality.
    for dim in self.ndims:
        range_len = dim * 2
        code.raw(f"""
if (ndim == {dim}){{
std::array<float, {dim}> vsize_;
std::array<int, {dim}> grid_size_, grid_stride_;
std::array<float, {range_len}> coors_range_;
for (int i = 0; i < {dim}; ++i){{
vsize_[i] = vsize[i];
grid_size_[i] = grid_size[i];
grid_stride_[i] = grid_stride[i];
coors_range_[i] = coors_range[i];
coors_range_[i + {dim}] = coors_range[i + {dim}];
}}
return Point2Voxel{dim}D::point_to_voxel_hash_static(points, voxels, indices,
num_per_voxel, hashdata, point_indice_data,
vsize_, grid_size_, grid_stride_, coors_range_, clear_voxels,
empty_mean, stream_int);
}}
""")

    # Reached only when no branch above returned.
    code.raw('TV_THROW_RT_ERR("unknown ndim", ndim);')
    return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
\ No newline at end of file
#!/home/yy/library/anaconda3/bin/python
# Developer script preamble: set up the import path, then pull in the
# spconv modules exercised by main().
import sys
from pathlib import Path
import ctypes  # only referenced by the commented-out CDLL line below
# _cudart = ctypes.CDLL('libcudart.so')
# Print, then append to sys.path, the directory four levels above this file
# (presumably the repository root so that `spconv` resolves -- TODO confirm).
print(str(Path(__file__).parent.parent.parent.parent))
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
from spconv import tensorview as tv
from spconv.sparse import build
import numpy as np
from pathlib import Path  # NOTE(review): duplicate of the import above
from spconv.spconv_ops_cc.sparse.all.ops import Point2Voxel
from spconv.spconv_ops_cc.sparse.all import SpconvOps
import time
def main():
    """Developer benchmark for GPU voxelization.

    Loads a point cloud from a hard-coded ``.npz`` file, runs
    ``point_to_voxel_hash`` six times while printing the host wall-clock
    time of each call (no explicit device synchronization is visible
    here), runs it once more, then builds batch-prefixed indices and
    allocates the GPU buffers used by the (currently commented-out)
    index-generation experiments.
    """
    pc_path = "/home/yy/OneDrive/dev/spconv/test/data/benchmark-pc.npz"
    pc = np.load(pc_path)["pc"].astype(np.float32)
    print(pc.shape, pc.dtype)

    p2v = Point2Voxel([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3, 150000, 1)
    gs = p2v.grid_size  # zyx
    print(gs)
    # return

    pc_tv = tv.from_numpy(pc).cuda()
    for _ in range(6):
        start = time.time()
        voxels, indices, num_per_voxel = p2v.point_to_voxel_hash(pc_tv)
        print(time.time() - start)
    voxels, indices, num_per_voxel = p2v.point_to_voxel_hash(pc_tv)
    print(voxels.shape, gs)

    gs_xyz = gs
    indices_np = indices.cpu().numpy()
    # indices_offset = indices_np[:, 0] * gs_xyz[1] * gs_xyz[2] + indices_np[:, 1] * gs_xyz[2] + indices_np[:, 2]
    # uq = np.unique(indices_offset)
    # print(uq.shape, indices_offset.shape, gs_xyz)
    # return

    ksize = [3] * 3
    kv = int(np.prod(ksize))

    # Column 0 stays zero (single batch); the remaining columns hold the voxel indices.
    num_voxels = indices_np.shape[0]
    indices_with_bs = np.zeros((num_voxels, 4), dtype=np.int32)
    indices_with_bs[:, 1:] = indices_np
    print(indices_with_bs.mean(), indices_with_bs.max(), indices_with_bs.min())

    indices = tv.from_numpy(indices_with_bs).cuda()
    num_act = indices.dim(0)
    out_indices = tv.zeros([num_act * kv, 4], tv.int32, 0)
    indice_num_per_loc = tv.zeros([kv], tv.int32, 0)

    points = voxels.view([-1, 3])
    num_points = points.dim(0)
    hashdata = tv.zeros([num_points * kv * 2], tv.custom64, 0)
    hashdata_subm = tv.zeros([num_points * 2], tv.custom64, 0)

    indice_pairs = tv.full([2, kv, num_act], -1, tv.int32, 0)
    indice_pairs_uniq = tv.zeros([indice_pairs.size // 2 + 1], tv.int32, 0)
    # for i in range(10):
    #     indice_pairs.fill_int_(-1)
    #     np.random.shuffle(indices_with_bs)
    #     indices = tv.from_numpy(indices_with_bs).cuda()
    #     indice_num_per_loc.zero_()
    #     out_act = SpconvOps.generate_conv_inds(indices, hashdata, indice_pairs,
    #         indice_pairs_uniq, out_indices, indice_num_per_loc,
    #         1, gs, gs, [3, 3, 3], [1, 1, 1], [1, 1, 1], [1, 1, 1])
    #     indice_num_per_loc.zero_()
    # out_act = SpconvOps.generate_subm_conv_inds(indices, hashdata_subm, indice_pairs,
    #     out_indices, indice_num_per_loc,
    #     1, gs, ksize, [1, 1, 1])
    # indice_num_per_loc_cpu = indice_num_per_loc.cpu().numpy()
    # indice_pairs_cpu = indice_pairs.cpu().numpy()
    # indice_pairs_cpu_flat = indice_pairs_cpu.reshape(-1)
    # uq, count = np.unique(indice_pairs_cpu_flat, return_counts=True)
    # print(out_act, indice_pairs_cpu.shape, indice_pairs_cpu.mean(), indice_num_per_loc_cpu.tolist())
    # print(indice_pairs_cpu[:, 13, :2])
    # print(uq, count)


if __name__ == "__main__":
    main()
\ No newline at end of file
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pccm
from cumm.common import TensorView
from typing import List
class GatherCPU(pccm.Class):
    """pccm class emitting host-side row gather / scatter-add routines.

    Both generated functions take ``out``, ``in`` and ``inds`` tensors,
    read ``inds`` as ``int`` and dispatch on ``out.dtype()`` over
    float / double.
    """

    # Every generated function shares this tensor argument list.
    _TENSOR_ARGS = ("out", "in", "inds")

    def __init__(self):
        super().__init__()
        self.add_dependency(TensorView)

    @pccm.static_function
    def gather(self):
        """Emit ``gather``: row ``i`` of ``out`` receives row ``inds[i]`` of ``in``.

        Rows are ``in.dim(1)`` elements wide and are copied with
        ``std::memcpy``.
        """
        code = pccm.FunctionCode()
        for tensor_arg in self._TENSOR_ARGS:
            code.arg(tensor_arg, "tv::Tensor")
        code.raw(f"""
// tv::check_shape(inds, {{out.dim(0)}});
auto nhot = inds.dim(0);
int channel = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{
auto indices_data = inds.data_ptr<const int>();
using T = TV_DECLTYPE(I);
T *buffer_data = out.data_ptr<T>();
const T *features_data = in.data_ptr<const T>();
for (int i = 0; i < nhot; ++i) {{
std::memcpy(buffer_data + i * channel,
features_data + indices_data[i] * channel,
sizeof(T) * channel);
}}
}});
""")
        return code

    @pccm.static_function
    def scatter_add(self):
        """Emit ``scatter_add``: row ``i`` of ``in`` is added onto row ``inds[i]`` of ``out``.

        Rows are ``in.dim(1)`` elements wide and are accumulated element by
        element.
        """
        code = pccm.FunctionCode()
        for tensor_arg in self._TENSOR_ARGS:
            code.arg(tensor_arg, "tv::Tensor")
        code.raw(f"""
// tv::check_shape(inds, {{in.dim(0)}});
auto nhot = inds.dim(0);
int channel = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto indices_data = inds.data_ptr<const int>();
const T *buffer_data = in.data_ptr<const T>();
T *features_data = out.data_ptr<T>();
const T *buf = in.data_ptr<const T>();
T *out_ptr = out.data_ptr<T>();
for (int i = 0; i < nhot; ++i) {{
buf = buffer_data + i * channel;
out_ptr = features_data + indices_data[i] * channel;
for (int j = 0; j < channel; ++j) {{
out_ptr[j] = out_ptr[j] + buf[j];
}}
}}
}});
""")
        return code
...@@ -56,7 +56,6 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -56,7 +56,6 @@ class CudaCommonKernel(pccm.ParameterizedClass):
class ConvOutLocIter(pccm.ParameterizedClass): class ConvOutLocIter(pccm.ParameterizedClass):
# TODO add conv transpose
def __init__(self, problem: ConvProblem): def __init__(self, problem: ConvProblem):
super().__init__() super().__init__()
self.add_dependency(TensorView) self.add_dependency(TensorView)
...@@ -73,7 +72,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -73,7 +72,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
self.add_member("layout_npq", f"LayoutNPQ") self.add_member("layout_npq", f"LayoutNPQ")
self.add_member("layout_rs", f"LayoutRS") self.add_member("layout_rs", f"LayoutRS")
@pccm.cuda.constructor(host=True, device=True, forceinline=True) @pccm.constructor(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"])
def ctor(self): def ctor(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("problem", f"ConvProblem const&") code.arg("problem", f"ConvProblem const&")
...@@ -88,9 +87,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -88,9 +87,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
return code return code
@pccm.cuda.member_function(host=True, @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"],
device=True,
forceinline=True,
name="operator++") name="operator++")
def increment(self): def increment(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -104,9 +101,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -104,9 +101,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
code.raw("return *this;") code.raw("return *this;")
return code.ret(f"{self.class_name}&") return code.ret(f"{self.class_name}&")
@pccm.cuda.member_function(host=True, @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"])
device=True,
forceinline=True)
def set_filter_offset(self): def set_filter_offset(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("filter_offset", "int") code.arg("filter_offset", "int")
...@@ -115,9 +110,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -115,9 +110,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.member_function(host=True, @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"],
device=True,
forceinline=True,
const=True) const=True)
def nhw_to_npq(self): def nhw_to_npq(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -135,9 +128,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -135,9 +128,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
""") """)
return code.ret(f"tv::array<int, {self.ndim + 1}>") return code.ret(f"tv::array<int, {self.ndim + 1}>")
@pccm.cuda.member_function(host=True, @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"],
device=True,
forceinline=True,
const=True) const=True)
def npq_to_nhw(self): def npq_to_nhw(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -154,9 +145,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -154,9 +145,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
return code.ret(f"tv::array<int, {self.ndim + 1}>") return code.ret(f"tv::array<int, {self.ndim + 1}>")
@pccm.cuda.member_function(host=True, @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"],
device=True,
forceinline=True,
const=True) const=True)
def query_npq(self): def query_npq(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -181,9 +170,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -181,9 +170,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.member_function(host=True, @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"],
device=True,
forceinline=True,
const=True) const=True)
def query_npq_no_stride(self): def query_npq_no_stride(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -203,9 +190,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -203,9 +190,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.member_function(host=True, @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"],
device=True,
forceinline=True,
const=True) const=True)
def query_nhw(self): def query_nhw(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -225,9 +210,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -225,9 +210,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.member_function(host=True, @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"],
device=True,
forceinline=True,
const=True) const=True)
def query_nhw_out(self): def query_nhw_out(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -305,6 +288,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -305,6 +288,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def build_conv_hash_table(self): def build_conv_hash_table(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -319,10 +303,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -319,10 +303,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices", "int") code.arg("num_indices", "int")
code.raw(f""" code.raw(f"""
for (int i : tv::KernelLoopX<int>(num_indices)) {{ for (int output_index : tv::KernelLoopX<int>(num_indices)) {{
{self.dtype_indices} index = indice_pairs_for_uniq[i]; {self.dtype_indices} output_coord_offset = indice_pairs_for_uniq[output_index];
layout_npq.inverse(index, indices_out + {self.ndim + 1} * i); layout_npq.inverse(output_coord_offset, indices_out + {self.ndim + 1} * output_index);
table.insert(index, i); table.insert(output_coord_offset, output_index);
}} }}
""") """)
return code return code
...@@ -340,9 +324,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -340,9 +324,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size; auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{ for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
{self.dtype_indices} index = indice_pairs_out_part_filter[i]; {self.dtype_indices} output_coord_offset = indice_pairs_out_part_filter[i];
if (index > -1){{ if (output_coord_offset > -1){{
auto ptr = table.lookup_ptr(index); auto ptr = table.lookup_ptr(output_coord_offset);
if (ptr){{ if (ptr){{
indice_pairs_out_part_filter[i] = ptr->second; indice_pairs_out_part_filter[i] = ptr->second;
}} }}
...@@ -351,6 +335,141 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -351,6 +335,141 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask(self):
    """Emit the stage-1 CUDA kernel of the mask-based index generation.

    ``blockIdx.y`` selects the filter offset.  For every input index the
    kernel queries the output coordinate through ``loc_iter``
    (``query_nhw_out`` when ``transposed``, ``query_npq`` otherwise).  For
    a valid coordinate it atomically increments
    ``indice_num_per_loc[filter_offset]`` and stores the linearized output
    coordinate in two places:

    * ``indice_pairs_bwd[filter_offset * num_indices_in + input_index]``
    * ``indice_pairs_for_uniq[filter_offset * num_indices_in + old_num]``

    ``RS`` only feeds a local that the kernel body never reads.
    """
    coord_len = self.ndim + 1  # ints read per input coordinate
    index_t = self.dtype_indices
    index_ptr_t = f"{index_t}*"

    code = pccm.FunctionCode()
    code.arg("loc_iter", "ConvLocIter")
    code.arg("indices_in", "const int*")  # coord_len ints per input index
    code.arg("indice_pairs_bwd", index_ptr_t)  # indexed [filter_offset, input_index]
    code.arg("indice_pairs_for_uniq", index_ptr_t)  # indexed [filter_offset, old_num]
    code.arg("indice_num_per_loc", "int*")  # one counter per filter offset
    code.arg("num_indices_in", "int")
    code.arg("RS", "int")
    code.arg("transposed", "bool")
    code.raw(f"""
int filter_offset = blockIdx.y;
loc_iter.set_filter_offset(filter_offset);
int indices_pair_size_mul_RS = num_indices_in * RS;
int filter_offset_mul_indices_pair_size = filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
tv::array<int, {coord_len}> npq_offset;
bool valid;
if (transposed){{
valid = loc_iter.query_nhw_out(indices_in + input_index * {coord_len}, npq_offset);
}}else{{
valid = loc_iter.query_npq(indices_in + input_index * {coord_len}, npq_offset);
}}
if (valid){{
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
{index_t} output_coord_offset = loc_iter.layout_npq(npq_offset);
// if (old_num < indices_pair_size){{
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
// }}
}}
}}
""")
    return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_mask(self):
    """Emit the stage-2 CUDA kernel that resolves output indices (training path).

    ``blockIdx.y`` selects the filter offset.  For every input index whose
    ``indice_pairs_bwd`` entry is greater than -1, the stored output
    coordinate offset is looked up in ``table``.  On a hit the kernel:

    * ORs bit ``filter_offset`` into ``mask_fwd[output_index]``,
    * writes ``input_index`` to ``indice_pairs_fwd[filter_offset, output_index]``,
    * overwrites ``indice_pairs_bwd[filter_offset, input_index]`` with
      ``output_index``.

    ``mask_bwd`` is accepted but not written here (the write is commented
    out in the kernel body).
    """
    index_t = self.dtype_indices

    code = pccm.FunctionCode()
    code.targ("TTable")
    code.arg("table", "TTable")  # output coord offset -> output index
    code.arg("indice_pairs_fwd", "int*")  # row stride num_indices_out, stores input index
    code.arg("indice_pairs_bwd", "int*")  # row stride num_indices_in, stores output index
    code.arg("mask_fwd", "uint32_t*")  # indexed by output index
    code.arg("mask_bwd", "uint32_t*")  # unused by this kernel
    code.arg("num_indices_in", "int")
    code.arg("num_indices_out", "int")
    # TODO use block instead of filter_offset?
    code.raw(f"""
int filter_offset = blockIdx.y;
uint32_t filter_mask_fwd = (1u << (filter_offset));
// TODO following rule for even kernel size is wrong.
// uint32_t filter_mask_bwd = (1u << (gridDim.y - 1 - filter_offset));
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
{index_t} output_coord_offset = indice_pairs_bwd_filter[input_index];
if (output_coord_offset > -1){{
auto ptr = table.lookup_ptr(output_coord_offset);
if (ptr){{
auto output_index = ptr->second;
atomicOr(mask_fwd + output_index, filter_mask_fwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd);
indice_pairs_fwd_filter[output_index] = input_index;
indice_pairs_bwd_filter[input_index] = output_index;
}}
}}
}}
""")
    return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_mask_output(self):
    """Emit the CUDA kernel that builds the per-input backward mask.

    For every input index the kernel scans all ``kv`` filter offsets of
    ``indice_pairs_bwd`` (row stride ``num_indices_in``) and sets bit
    ``filter_offset`` of ``mask_bwd[input_index]`` whenever the stored
    value is not -1.

    The bit is produced from an unsigned operand, matching the
    ``1u << filter_offset`` form used by the sibling stage-2 kernels, so
    that offset 31 does not shift a signed int into its sign bit.
    """
    code = pccm.FunctionCode()
    code.arg("indice_pairs_bwd", "int*")  # row stride num_indices_in, -1 marks "no pair"
    code.arg("mask_bwd", "uint32_t*")  # indexed by input index
    code.arg("num_indices_in", "int")
    code.arg("kv", "int")
    # TODO use block instead of filter_offset?
    code.raw(f"""
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
uint32_t mask = 0;
for (int filter_offset = 0; filter_offset < kv; ++filter_offset){{
auto val = indice_pairs_bwd[filter_offset * num_indices_in + input_index];
mask |= static_cast<uint32_t>(val != -1) << filter_offset;
}}
mask_bwd[input_index] = mask;
}}
""")
    return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_inference_mask(self):
    """Emit the stage-2 CUDA kernel that resolves output indices (inference path).

    ``blockIdx.y`` selects the filter offset.  For every input index whose
    ``indice_pairs_bwd`` entry is greater than -1, the stored output
    coordinate offset is looked up in ``table``.  On a hit the kernel ORs
    bit ``filter_offset`` into ``mask_fwd[output_index]`` and writes
    ``input_index`` to ``indice_pairs_fwd[filter_offset, output_index]``.
    Unlike the training variant, ``indice_pairs_bwd`` is only read.
    """
    index_t = self.dtype_indices

    code = pccm.FunctionCode()
    code.targ("TTable")
    code.arg("table", "TTable")  # output coord offset -> output index
    code.arg("indice_pairs_fwd", "int*")  # row stride num_indices_out, stores input index
    code.arg("indice_pairs_bwd", "int*")  # row stride num_indices_in, read-only here
    code.arg("mask_fwd", "uint32_t*")  # indexed by output index
    code.arg("num_indices_in", "int")
    code.arg("num_indices_out", "int")
    # TODO use block instead of filter_offset?
    code.raw(f"""
int filter_offset = blockIdx.y;
uint32_t filter_mask_fwd = (1u << (filter_offset));
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
{index_t} output_coord_offset = indice_pairs_bwd_filter[input_index];
if (output_coord_offset > -1){{
auto ptr = table.lookup_ptr(output_coord_offset);
if (ptr){{
auto output_index = ptr->second;
atomicOr(mask_fwd + output_index, filter_mask_fwd);
indice_pairs_fwd_filter[output_index] = input_index;
}}
}}
}}
""")
    return code
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def build_subm_conv_hash_table(self): def build_subm_conv_hash_table(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -475,9 +594,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -475,9 +594,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
atomicOr(mask + input_index, filter_mask_in); atomicOr(mask + input_index, filter_mask_in);
// for this output, we set correct input idx. // for this output, we set correct input idx.
indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index; indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + input_index] = output_index;
// the output in "input location" connect this output idx in another location. // the output in "input location" connect this output idx in another location.
indice_pairs[filter_offset_mul_indices_pair_size_1 + input_index] = output_index; indice_pairs[filter_offset_mul_indices_pair_size_1 + input_index] = output_index;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + input_index] = output_index;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + output_index] = input_index; indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + output_index] = input_index;
}} }}
}} }}
...@@ -559,8 +678,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -559,8 +678,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// TODO handle num input == 0 // TODO handle num input == 0
int kv = tv::arrayops::prod(ksize); int kv = tv::arrayops::prod(ksize);
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}");
// indice_pairs: [2, kv, indices.dim(0)] // indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1] // indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
tv::check_shape(indice_pairs, {{2, kv, indices.dim(0)}});
tv::check_shape(indice_num_per_loc, {{kv}});
int64_t uniq_size = indice_pairs.size() / 2 + 1; int64_t uniq_size = indice_pairs.size() / 2 + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= uniq_size, "error"); TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= uniq_size, "error");
TV_ASSERT_RT_ERR(indice_num_per_loc.dim(0) == kv, "error"); TV_ASSERT_RT_ERR(indice_num_per_loc.dim(0) == kv, "error");
...@@ -585,6 +708,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -585,6 +708,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""") """)
return code# .ret("int") return code# .ret("int")
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_stage1_5(self): def generate_conv_inds_stage1_5(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -622,7 +747,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -622,7 +747,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// indice_pairs: [2, kv, indices.dim(0)] // indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1] // indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [MaxSize, {self.ndim + 1}]
auto timer = tv::CudaContextTimer<>(); // auto timer = tv::CudaContextTimer<>();
int64_t uniq_size = indice_pairs.size() / 2 + 1; int64_t uniq_size = indice_pairs.size() / 2 + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= num_out_act, "error"); TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= num_out_act, "error");
TV_ASSERT_RT_ERR(out_inds.dim(0) >= num_out_act && out_inds.dim(1) == {self.ndim + 1}, "error"); TV_ASSERT_RT_ERR(out_inds.dim(0) >= num_out_act && out_inds.dim(1) == {self.ndim + 1}, "error");
...@@ -654,6 +779,130 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -654,6 +779,130 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""") """)
return code.ret("int") return code.ret("int")
@pccm.cuda.static_function
def generate_conv_inds_mask_stage1(self):
    """Emit the host launcher for stage 1 of the mask-based index generation.

    The generated function:

    * asserts that ``prod(input_dims)`` fits in the index dtype,
    * checks ``indice_pairs_bwd`` is ``[kv, indices.dim(0)]`` and
      ``indice_num_per_loc`` is ``[kv]``,
    * requires ``indice_pairs_uniq`` to hold at least
      ``indice_pairs_bwd.size() + 1`` entries,
    * launches ``clean_indices_uniq`` over ``indice_pairs_uniq``, then
      ``calc_conv_indices_stage1_mask`` with ``blocks.y = kv``; both run on
      the stream given by ``stream_int``.

    Changes from the previous revision: the assertion message now names
    the quantity that is actually checked, and the unused profiling timer
    at the end of the body is disabled (as already done in
    ``generate_conv_inds_stage1_5``).  The unused ``expected_out_size``
    local was dropped.
    """
    index_t = self.dtype_indices
    dims_t = f"tv::array<int, {self.ndim}>"

    code = pccm.FunctionCode()
    code.arg("indices", "tv::Tensor")
    code.arg("indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc", "tv::Tensor")
    code.arg("batch_size", "int")
    code.arg("output_dims, input_dims", dims_t)
    code.arg("ksize, stride, padding, dilation", dims_t)
    code.arg("transposed", "bool", "false")
    code.arg("stream_int", "std::uintptr_t", "0")
    code.raw(f"""
// TODO handle num input == 0
int kv = tv::arrayops::prod(ksize);
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<{index_t}>::max(),
"prod(input_dims) must not exceed max value of {index_t}");
// indice_pairs_bwd: [kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs_bwd.size() + 1]
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}});
tv::check_shape(indice_num_per_loc, {{kv}});
int64_t uniq_size = indice_pairs_bwd.size() + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= uniq_size, "error");
tv::cuda::Launch launcher_num_act_in(indices.dim(0), reinterpret_cast<cudaStream_t>(stream_int));
// tv::cuda::Launch launcher_num_act_in_2(indices.dim(0));
launcher_num_act_in.blocks.y = kv;
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
launcher_clean_uniq(clean_indices_uniq, indice_pairs_uniq.data_ptr<{index_t}>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1_mask, loc_iter, indices.data_ptr<const int>(),
indice_pairs_bwd.data_ptr<{index_t}>(),
indice_pairs_uniq.data_ptr<{index_t}>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
kv, transposed);
// auto timer = tv::CudaContextTimer<>();
""")
    return code  # .ret("int")
@pccm.cuda.static_function
def generate_conv_inds_stage2_mask(self):
    """Emit the host launcher for stage 2 of the mask-based index generation.

    The generated function:

    * checks ``indice_pairs_bwd`` is ``[kv, indices.dim(0)]``,
      ``indice_pairs_fwd`` is ``[kv, num_out_act]`` and ``out_inds`` is
      ``[num_out_act, ndim + 1]``,
    * slices ``indice_pairs_uniq`` to its first ``num_out_act`` entries and
      builds a linear hash table over them in ``hashdata`` (which must
      hold at least ``num_out_act`` entries) via ``build_conv_hash_table``,
    * when ``mask_bwd`` is non-empty, launches
      ``calc_conv_indices_stage2_mask`` followed by
      ``calc_conv_indices_stage2_mask_output``; otherwise launches
      ``calc_conv_indices_stage2_inference_mask``,
    * when a mask tensor has ``dim(0) == 2``, copies its row 0 into row 1,
    * returns ``num_out_act``.

    All launches use the stream given by ``stream_int``.
    """
    coord_len = self.ndim + 1
    index_t = self.dtype_indices
    dims_t = f"tv::array<int, {self.ndim}>"

    code = pccm.FunctionCode()
    arg_specs = (
        ("indices, hashdata", "tv::Tensor"),
        ("indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds", "tv::Tensor"),
        ("mask_fwd, mask_bwd", "tv::Tensor"),
        ("num_out_act", "int"),
        ("batch_size", "int"),
        ("output_dims, input_dims", dims_t),
        ("ksize, stride, padding, dilation", dims_t),
        ("transposed", "bool", "false"),
        ("stream_int", "std::uintptr_t", "0"),
    )
    for spec in arg_specs:
        code.arg(*spec)
    code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream
// TODO handle num input == 0
int kv = tv::arrayops::prod(ksize);
// indice_pairs_bwd: [kv, indices.dim(0)]
// indice_pairs_fwd: [kv, out_inds.dim(0)]
auto ctx = tv::Context();
ctx.set_cuda_stream(custream);
// out_inds: [MaxSize, {coord_len}]
// auto timer = tv::CudaContextTimer<>();
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}});
tv::check_shape(indice_pairs_fwd, {{kv, num_out_act}});
tv::check_shape(out_inds, {{num_out_act, {coord_len}}});
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
launcher_num_act_in.blocks.y = kv;
tv::cuda::Launch launcher_num_act_in_no_y(indices.dim(0), custream);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
// TODO handle invalid num_out_act
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
using V = {index_t};
using KeyType = {index_t};
constexpr KeyType kEmptyKey = std::numeric_limits<KeyType>::max();
using table_t =
tv::hash::LinearHashTable<KeyType, V, tv::hash::Murmur3Hash<KeyType>,
kEmptyKey, false>;
using pair_t = typename table_t::value_type;
TV_ASSERT_RT_ERR(hashdata.dim(0) >= num_out_act, "hash size not enough");
table_t hash = table_t(hashdata.data_ptr<pair_t>(), hashdata.dim(0));
hash.clear(custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const {index_t}>(),
loc_iter.layout_npq, num_out_act);
if (!mask_bwd.empty()){{
// auto timer = tv::CudaContextTimer<>();
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1));
// tv::ssprint("calc_conv_indices_stage2_mask", timer.report() / 1000.0);
launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output, indice_pairs_bwd.data_ptr<int>(),
mask_bwd.data_ptr<uint32_t>(),
indice_pairs_bwd.dim(1), kv);
// tv::ssprint("calc_conv_indices_stage2_mask_output", timer.report() / 1000.0);
if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx);
}}
if (mask_bwd.dim(0) == 2){{
mask_bwd[1].copy_(mask_bwd[0], ctx);
}}
}}else{{
launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
mask_fwd.data_ptr<uint32_t>(),
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1));
if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx);
}}
}}
return num_out_act;
""")
    return code.ret("int")
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_subm_conv_inds(self): def generate_subm_conv_inds(self):
...@@ -691,7 +940,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -691,7 +940,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream); tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
launcher_num_act_in.blocks.y = (kv / 2) + 1; launcher_num_act_in.blocks.y = (kv / 2) + 1;
// launcher_num_act_in.blocks.y = kv; // launcher_num_act_in.blocks.y = kv;
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}");
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); ConvLocIter loc_iter(problem);
...@@ -713,7 +963,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -713,7 +963,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
loc_iter.layout_npq, indices.dim(0)); loc_iter.layout_npq, indices.dim(0));
// tv::ssprint("build_hash time", timer.report() / 1000.0); // tv::ssprint("build_hash time", timer.report() / 1000.0);
if (!indice_pair_mask.empty()){{ if (!indice_pair_mask.empty()){{
if (indice_pair_mask.ndim() == 2 && indice_pair_mask.dim(0) == 2){{ TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error");
if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0]; auto mask_0 = indice_pair_mask[0];
tv::cuda::Launch lanucher_fill(mask_0.size(), custream); tv::cuda::Launch lanucher_fill(mask_0.size(), custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), mask_0.size()); lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), mask_0.size());
...@@ -726,7 +977,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -726,7 +977,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}}else{{ }}else{{
tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream); tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indice_pair_mask.size()); lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indice_pair_mask.size());
TV_ASSERT_RT_ERR(indice_pair_mask.ndim() == 1, "error"); TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash, launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(), indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv); indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv);
...@@ -741,3 +992,141 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -741,3 +992,141 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""") """)
return code.ret("int") return code.ret("int")
class SparseConvIndicesCPU(pccm.ParameterizedClass):
    """Code generator for the CPU indice-pair builders of sparse convolution.

    Two static C++ functions are emitted: ``generate_subm_conv_inds`` and
    ``generate_conv_inds``. Both use a ``std::unordered_map`` keyed by the
    flattened location returned by ``ConvLocIter::layout_npq``.
    """

    def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType):
        super().__init__()
        self.add_dependency(TensorView)
        self.add_include("unordered_map")
        self.loc_iter = ConvOutLocIter(problem)
        # Both parameter classes are registered under the "spinds" namespace.
        for param_cls, alias in ((self.loc_iter, "ConvLocIter"),
                                 (problem, "ConvProblem")):
            self.add_param_class("spinds", param_cls, alias)
        self.ndim = problem.ndim
        self.dtype_indices = dtype_indices
        self.dtype_indices_uniq = dtype_indices
        # Only 32-bit and 64-bit integer indices are accepted.
        assert dtype_indices in (dtypes.int32, dtypes.int64)

    @pccm.static_function
    def generate_subm_conv_inds(self):
        """Emit the C++ builder of indice pairs for submanifold convolution.

        The generated function requires odd kernel sizes, uses stride 1 and
        padding ``(ksize / 2) * dilation``, and hashes every input location.
        It walks the first ``kv / 2 + 1`` filter offsets; for each hit it
        writes the pair for that offset and the swapped pair for the mirrored
        offset ``kv - 1 - filter_offset``. The centre offset maps every input
        to itself. The generated function returns ``indices.dim(0)``.
        """
        dims_type = f"tv::array<int, {self.ndim}>"
        func_code = pccm.FunctionCode()
        func_code.arg("indices", "tv::Tensor")
        func_code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor")
        func_code.arg("batch_size", "int")
        func_code.arg("input_dims", dims_type)
        func_code.arg("ksize, dilation", dims_type)
        # The C++ text is kept at column 0 so the emitted source is unchanged.
        func_code.raw(f"""
tv::array<int, {self.ndim}> stride, padding;
for (int i = 0; i < {self.ndim}; ++i){{
TV_ASSERT_RT_ERR(ksize[i] % 2 == 1, "subm only support odd ksize");
stride[i] = 1;
padding[i] = (ksize[i] / 2) * dilation[i];
}}
int kv = tv::arrayops::prod(ksize);
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
int indices_pair_size = indice_pairs.dim(2);
int indices_pair_size_mul_RS = indices_pair_size * kv;
auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>();
std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash;
auto indices_ptr = indices.data_ptr<{self.dtype_indices}>();
int indice_in_num = indices.dim(0);
for (int i = 0; i < indice_in_num; ++i){{
{self.dtype_indices} index = loc_iter.layout_npq(indices_ptr);
hash.insert({{index, i}});
indices_ptr += {self.ndim + 1};
}}
for (int filter_offset = 0; filter_offset < (kv / 2 + 1); ++filter_offset){{
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
int filter_offset_mul_indices_pair_size_1 = (kv - 1 - filter_offset) * indices_pair_size;
if (filter_offset == kv / 2){{
for (int i = 0; i < indice_in_num; ++i){{
indice_pairs_ptr[filter_offset_mul_indices_pair_size + i] = i;
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + i] = i;
}}
}}else{{
indices_ptr = indices.data_ptr<{self.dtype_indices}>();
auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset;
for (int i = 0; i < indice_in_num; ++i){{
tv::array<int, {self.ndim + 1}> npq_offset;
if (loc_iter.query_npq_no_stride(indices_ptr, npq_offset)){{
auto index = loc_iter.layout_npq(npq_offset);
auto iter = hash.find(index);
if (iter != hash.end()){{
auto old_num = indice_num_per_loc_ptr[0]++;
indice_pairs_ptr[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = iter->second;
indice_pairs_ptr[filter_offset_mul_indices_pair_size_1 + old_num] = iter->second;
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + old_num] = i;
}}
}}
indices_ptr += {self.ndim + 1};
}}
}}
++loc_iter;
}}
return indices.dim(0);
""")
        return func_code.ret("int")

    @pccm.static_function
    def generate_conv_inds(self):
        """Emit the C++ builder of indice pairs for regular sparse convolution.

        For every filter offset and every input index the generated code
        queries the output location (``query_nhw_out`` when ``transposed`` is
        set, ``query_npq`` otherwise). Output locations seen for the first
        time are appended to ``out_inds`` and numbered in discovery order.
        Each valid (input, output) pair is stored in the offset's slot and the
        offset's counter in ``indice_num_per_loc`` is incremented. The
        generated function returns the number of distinct output locations.
        """
        dims_type = f"tv::array<int, {self.ndim}>"
        func_code = pccm.FunctionCode()
        func_code.arg("indices", "tv::Tensor")
        func_code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor")
        func_code.arg("batch_size", "int")
        func_code.arg("output_dims, input_dims", dims_type)
        func_code.arg("ksize, stride, padding, dilation", dims_type)
        func_code.arg("transposed", "bool", "false")
        # The C++ text is kept at column 0 so the emitted source is unchanged.
        func_code.raw(f"""
int kv = tv::arrayops::prod(ksize);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
int indices_pair_size = indice_pairs.dim(2);
int indices_pair_size_mul_RS = indices_pair_size * kv;
auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>();
std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash;
auto indices_ptr = indices.data_ptr<{self.dtype_indices}>();
auto out_inds_ptr = out_inds.data_ptr<{self.dtype_indices}>();
int indice_in_num = indices.dim(0);
int num_act = 0;
{self.dtype_indices} hashval;
for (int filter_offset = 0; filter_offset < kv; ++filter_offset){{
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
indices_ptr = indices.data_ptr<{self.dtype_indices}>();
auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset;
for (int i = 0; i < indice_in_num; ++i){{
tv::array<int, {self.ndim + 1}> npq_offset;
bool valid;
if (transposed){{
valid = loc_iter.query_nhw_out(indices_ptr, npq_offset);
}}else{{
valid = loc_iter.query_npq(indices_ptr, npq_offset);
}}
if (valid){{
auto index = loc_iter.layout_npq(npq_offset);
auto iter = hash.find(index);
if (iter == hash.end()){{
hashval = num_act++;
hash.insert({{index, hashval}});
for (int k = 0; k < {self.ndim + 1}; ++k){{
out_inds_ptr[k] = npq_offset[k];
}}
out_inds_ptr += {self.ndim + 1};
}}else{{
hashval = iter->second;
}}
indice_pairs_ptr[filter_offset_mul_indices_pair_size + indice_num_per_loc_ptr[0]] = i;
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + indice_num_per_loc_ptr[0]++] = hashval;
}}
indices_ptr += {self.ndim + 1};
}}
++loc_iter;
}}
return num_act;
""")
        return func_code.ret("int")
...@@ -30,6 +30,7 @@ class IndiceMaxPool(pccm.Class): ...@@ -30,6 +30,7 @@ class IndiceMaxPool(pccm.Class):
# TODO optimize this function # TODO optimize this function
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_include("limits")
self.add_dependency(TensorViewKernel, TensorView, GemmBasic) self.add_dependency(TensorViewKernel, TensorView, GemmBasic)
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
...@@ -61,6 +62,44 @@ class IndiceMaxPool(pccm.Class): ...@@ -61,6 +62,44 @@ class IndiceMaxPool(pccm.Class):
""") """)
return code return code
@pccm.cuda.cuda_global_function
def forward_implicit_gemm_kernel(self):
    """Emit the CUDA kernel for max-pool forward over an indices table.

    For output row ``i`` (Y loop) and feature ``j`` (X loop) the kernel
    reads ``RS`` entries of ``indices`` spaced ``num_indices`` apart,
    skips entries equal to -1, and stores the largest gathered input value
    (starting from ``lowest``) into ``out_features``.
    """
    kernel = pccm.FunctionCode()
    kernel.targ("T")
    kernel.arg("out_features", "T*")
    kernel.arg("in_features", "const T*")
    kernel.arg("indices", "const int*")
    kernel.arg("num_features", "int")
    kernel.arg("RS", "int")
    kernel.arg("num_indices", "int")
    kernel.arg("lowest", "T")
    # The C++ text is kept at column 0 so the emitted source is unchanged.
    kernel.raw(f"""
for (int i : tv::KernelLoopY<int>(num_indices)) {{
auto out_ptr = out_features + i * num_features;
for (int j : tv::KernelLoopX<int>(num_features)) {{
auto indices_ptr = indices + i;
int in_idx = indices_ptr[0];
T in, in_temp;
in = lowest;
bool valid = in_idx != -1;
in_temp = valid ? in_features[in_idx * num_features + j] : lowest;
in = (in < in_temp && valid) ? in_temp: in;
indices_ptr += num_indices;
for (int k = 1; k < RS; ++k){{
in_idx = indices_ptr[0];
valid = in_idx != -1;
in_temp = valid ? in_features[in_idx * num_features + j] : lowest;
in = (in < in_temp && valid) ? in_temp: in;
indices_ptr += num_indices;
}}
out_ptr[j] = in;
}}
}}
""")
    return kernel
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def backward_kernel(self): def backward_kernel(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -93,6 +132,52 @@ class IndiceMaxPool(pccm.Class): ...@@ -93,6 +132,52 @@ class IndiceMaxPool(pccm.Class):
""") """)
return code return code
@pccm.cuda.cuda_global_function
def backward_implicit_gemm_kernel(self):
    """Emit the CUDA kernel for max-pool backward over an indices table.

    For input row ``i`` and feature ``j`` the kernel visits ``RS`` entries
    of ``indices_bwd`` spaced ``num_indices`` apart. Whenever an entry is
    not -1 and the referenced output value equals the input value, the
    matching ``dout`` value is accumulated. The sum is written (not added)
    to ``din_features``.
    """
    kernel = pccm.FunctionCode()
    kernel.targ("T")
    kernel.arg("out_features", "const T*")
    kernel.arg("in_features", "const T*")
    kernel.arg("dout_features", "const T*")
    kernel.arg("din_features", "T*")
    kernel.arg("indices_bwd", "const int*")
    kernel.arg("num_features", "int")
    kernel.arg("RS", "int")
    kernel.arg("num_indices", "int")
    # The C++ text is kept at column 0 so the emitted source is unchanged.
    kernel.raw(f"""
for (int i : tv::KernelLoopY<int>(num_indices)) {{
auto in_ptr = in_features + i * num_features;
auto din_ptr = din_features + i * num_features;
for (int j : tv::KernelLoopX<int>(num_features)) {{
auto indices_ptr = indices_bwd + i;
int out_idx = indices_ptr[0];
T in = in_ptr[j];
T sum_val = T(0);
// if idx invalid, we only need to ensure in not equal to out.
T out = out_idx != -1 ? out_features[out_idx * num_features + j] : T(0);
T dout = out_idx != -1 ? dout_features[out_idx * num_features + j] : T(0);
bool valid = in == out && out_idx != -1;
sum_val = valid ? sum_val + dout : sum_val;
indices_ptr += num_indices;
for (int k = 1; k < RS; ++k){{
out_idx = indices_ptr[0];
out = out_idx != -1 ? out_features[out_idx * num_features + j] : T(0);
dout = out_idx != -1 ? dout_features[out_idx * num_features + j] : T(0);
valid = in == out && out_idx != -1;
sum_val = valid ? sum_val + dout : sum_val;
indices_ptr += num_indices;
}}
din_ptr[j] = sum_val;
}}
}}
""")
    return kernel
@pccm.cuda.static_function @pccm.cuda.static_function
def forward(self): def forward(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -132,6 +217,49 @@ class IndiceMaxPool(pccm.Class): ...@@ -132,6 +217,49 @@ class IndiceMaxPool(pccm.Class):
""") """)
return code return code
@pccm.cuda.static_function
def forward_implicit_gemm(self):
    """Emit the host launcher for ``forward_implicit_gemm_kernel``.

    The generated function checks that ``inds`` has ``out.dim(0)`` columns
    and that ``in`` has ``out.dim(1)`` features, then dispatches on the
    dtype of ``out`` (float, double, half, bfloat16). The block X size is
    the first value of {512, 256, 128, 64, 32, 16} not larger than
    ``out.dim(1)``, falling back to 16, and the block holds 512 threads in
    total. ``inds.dim(0)`` is passed as ``RS`` and ``inds.dim(1)`` as
    ``num_indices``.
    """
    launch_code = pccm.FunctionCode()
    launch_code.arg("out", "tv::Tensor")
    launch_code.arg("in", "tv::Tensor")
    launch_code.arg("inds", "tv::Tensor")
    launch_code.arg("stream", "std::uintptr_t", "0")
    # NOTE(review): the emitted code computes numeric_limits<T>::lowest() and
    # immediately overwrites it with T(0), so pooled outputs are never below
    # zero. Confirm whether that is intended before changing it.
    # The C++ text is kept at column 0 so the emitted source is unchanged.
    launch_code.raw(f"""
auto nhot = out.dim(0);
tv::check_shape(inds, {{-1, nhot}});
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
constexpr int MaxThreads = 512;
tv::cuda::Launch launcher(1);
bool found = tv::dispatch_int_noexcept<512, 256, 128, 64, 32, 16>(out.dim(1), [](int my, int expect){{return my >= expect;}}, [&](auto V){{
// if out.dim(1) > value in list above, run this function.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
if (!found){{
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
T lowest = std::numeric_limits<T>::lowest();
lowest = T(0);
launcher(forward_implicit_gemm_kernel<T>, out.data_ptr<T>(), in.data_ptr<const T>(),
inds.data_ptr<const int>(), out.dim(1), inds.dim(0), inds.dim(1), lowest);
}});
""")
    return launch_code
@pccm.cuda.static_function @pccm.cuda.static_function
def backward(self): def backward(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -173,3 +301,133 @@ class IndiceMaxPool(pccm.Class): ...@@ -173,3 +301,133 @@ class IndiceMaxPool(pccm.Class):
}}); }});
""") """)
return code return code
@pccm.cuda.static_function
def backward_implicit_gemm(self):
    """Emit the host launcher for ``backward_implicit_gemm_kernel``.

    The generated function checks that ``inds`` has ``in.dim(0)`` columns
    and that ``in`` has ``out.dim(1)`` features, then dispatches on the
    dtype of ``out`` (float, double, half, bfloat16). The launch shape is
    chosen as in the forward launcher: block X size is the first value of
    {512, 256, 128, 64, 32, 16} not larger than ``out.dim(1)``, falling
    back to 16, with 512 threads per block.
    """
    launch_code = pccm.FunctionCode()
    launch_code.arg("out", "tv::Tensor")
    launch_code.arg("in", "tv::Tensor")
    launch_code.arg("dout", "tv::Tensor")
    launch_code.arg("din", "tv::Tensor")
    launch_code.arg("inds", "tv::Tensor")
    launch_code.arg("stream", "std::uintptr_t", "0")
    # The C++ text is kept at column 0 so the emitted source is unchanged.
    launch_code.raw(f"""
auto nhot = in.dim(0);
tv::check_shape(inds, {{-1, nhot}});
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
constexpr int MaxThreads = 512;
tv::cuda::Launch launcher(1);
bool found = tv::dispatch_int_noexcept<512, 256, 128, 64, 32, 16>(out.dim(1), [](int my, int expect){{return my >= expect;}}, [&](auto V){{
// if out.dim(1) > value in list above, run this function.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
if (!found){{
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
launcher(backward_implicit_gemm_kernel<T>, out.data_ptr<const T>(), in.data_ptr<const T>(),
dout.data_ptr<const T>(), din.data_ptr<T>(),
inds.data_ptr<const int>(), out.dim(1), inds.dim(0), inds.dim(1));
}});
""")
    return launch_code
class IndiceMaxPoolCPU(pccm.Class):
    """Code generator for the CPU max-pool forward/backward over index pairs."""

    def __init__(self):
        super().__init__()
        self.add_dependency(TensorView)

    @pccm.static_function
    def forward(self):
        """Emit the CPU max-pool forward.

        For each pair ``k`` the generated code compares row
        ``in_inds[k]`` of ``in`` with row ``out_inds[k]`` of ``out`` and
        keeps the larger value per feature in ``out``. The existing
        contents of ``out`` therefore take part in the maximum. Only float
        and double are dispatched; ``stream`` is accepted but not used.
        """
        host_code = pccm.FunctionCode()
        for tensor_name in ("out", "in", "out_inds", "in_inds"):
            host_code.arg(tensor_name, "tv::Tensor")
        host_code.arg("stream", "std::uintptr_t", "0")
        # The C++ text is kept at column 0 so the emitted source is unchanged.
        host_code.raw(f"""
int nhot = out_inds.dim(0);
int num_features = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto out_features = out.data_ptr<T>();
auto in_features = in.data_ptr<const T>();
auto in_indices = in_inds.data_ptr<const int>();
auto out_indices = out_inds.data_ptr<const int>();
for (int i = 0; i < nhot; ++i) {{
int in_idx = in_indices[i];
int out_idx = out_indices[i];
auto in_ptr = in_features + in_idx * num_features;
auto out_ptr = out_features + out_idx * num_features;
for (int j = 0; j < num_features; ++j) {{
auto in = in_ptr[j];
auto out = out_ptr[j];
if (in > out){{
out_ptr[j] = in;
}}
}}
}}
}});
""")
        return host_code

    @pccm.static_function
    def backward(self):
        """Emit the CPU max-pool backward.

        For each pair ``k`` and feature ``j``, when the input value equals
        the output value the generated code adds ``dout`` of the output row
        to ``din`` of the input row. ``din`` is accumulated into, not
        overwritten. Only float and double are dispatched; ``stream`` is
        accepted but not used.
        """
        host_code = pccm.FunctionCode()
        for tensor_name in ("out", "in", "dout", "din", "out_inds", "in_inds"):
            host_code.arg(tensor_name, "tv::Tensor")
        host_code.arg("stream", "std::uintptr_t", "0")
        # The C++ text is kept at column 0 so the emitted source is unchanged.
        host_code.raw(f"""
int nhot = out_inds.dim(0);
int num_features = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto out_features = out.data_ptr<const T>();
auto in_features = in.data_ptr<const T>();
auto dout_features = dout.data_ptr<const T>();
auto din_features = din.data_ptr<T>();
auto in_indices = in_inds.data_ptr<const int>();
auto out_indices = out_inds.data_ptr<const int>();
for (int i = 0; i < nhot; ++i) {{
int in_idx_offset = in_indices[i] * num_features;
int out_idx_offset = out_indices[i] * num_features;
auto in_ptr = in_features + in_idx_offset;
auto out_ptr = out_features + out_idx_offset;
auto din_ptr = din_features + in_idx_offset;
auto dout_ptr = dout_features + out_idx_offset;
for (int j = 0; j < num_features; ++j) {{
auto in = in_ptr[j];
auto out = out_ptr[j];
if (in == out){{
din_ptr[j] = din_ptr[j] + dout_ptr[j];
}}
}}
}}
}});
""")
        return host_code
...@@ -23,6 +23,94 @@ from typing import List ...@@ -23,6 +23,94 @@ from typing import List
from cumm.conv.params import ConvProblem from cumm.conv.params import ConvProblem
import numpy as np import numpy as np
class Point2VoxelCommon(pccm.ParameterizedClass):
    """Helpers shared by the point-to-voxel generators.

    Provides the grid meta-data computation and conversions between
    ``std::array`` and ``tv::array``.
    """

    def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True):
        super().__init__()
        self.add_dependency(TensorView)
        self.dtype = dtype
        self.ndim = ndim
        self.zyx = zyx
        # C++ return type of calc_meta_data:
        # (vsize, grid_size, grid_stride, coors_range).
        int_arr = f"std::array<int, {self.ndim}>"
        float_arr = f"std::array<float, {self.ndim}>"
        float_arr2 = f"std::array<float, {self.ndim * 2}>"
        self.calc_meta_ret = f"std::tuple<{float_arr}, {int_arr}, {int_arr}, {float_arr2}>"

    @pccm.pybind.mark
    @pccm.static_function
    def calc_meta_data(self):
        """Emit the function deriving grid meta data from xyz inputs.

        When ``self.zyx`` is set the voxel size and both halves of the
        coordinate range are reversed; otherwise they are copied as given.
        The grid size is the rounded extent divided by the voxel size, and
        the grid stride is the running product of grid sizes from the last
        axis to the first.
        """
        meta_code = pccm.FunctionCode()
        meta_code.arg("vsize_xyz", f"std::array<float, {self.ndim}>")
        meta_code.arg("coors_range_xyz", f"std::array<float, {self.ndim * 2}>")
        # The C++ text is kept at column 0 so the emitted source is unchanged.
        meta_code.raw(f"""
std::array<float, {self.ndim}> vsize;
std::array<int, {self.ndim}> grid_size, grid_stride;
std::array<float, {self.ndim * 2}> coors_range;
""")
        if self.zyx:
            # Reverse the axis order of the xyz inputs.
            axis_copy_src = f"""
for (int i = 0; i < {self.ndim}; ++i){{
vsize[{self.ndim - 1} - i] = vsize_xyz[i];
coors_range[{self.ndim - 1} - i] = coors_range_xyz[i];
coors_range[{2 * self.ndim - 1} - i] = coors_range_xyz[i + {self.ndim}];
}}
"""
        else:
            # Keep the axis order of the inputs.
            axis_copy_src = f"""
for (int i = 0; i < {self.ndim}; ++i){{
vsize[i] = vsize_xyz[i];
coors_range[i] = coors_range_xyz[i];
coors_range[i + {self.ndim}] = coors_range_xyz[i + {self.ndim}];
}}
"""
        meta_code.raw(axis_copy_src)
        meta_code.raw(f"""
int64_t prod = 1;
for (size_t i = 0; i < {self.ndim}; ++i) {{
grid_size[i] =
std::round((coors_range[{self.ndim} + i] - coors_range[i]) / vsize[i]);
}}
for (int i = {self.ndim} - 1; i >= 0; --i) {{
grid_stride[i] = prod;
prod *= grid_size[i];
}}
return std::make_tuple(vsize, grid_size, grid_stride, coors_range);
""")
        int_arr = f"std::array<int, {self.ndim}>"
        float_arr = f"std::array<float, {self.ndim}>"
        float_arr2 = f"std::array<float, {self.ndim * 2}>"
        return meta_code.ret(f"std::tuple<{float_arr}, {int_arr}, {int_arr}, {float_arr2}>")

    @pccm.static_function
    def array2tvarray(self):
        """Emit a template copying a ``std::array<T, N>`` into a ``tv::array<T, N>``."""
        conv_code = pccm.FunctionCode()
        conv_code.targ("T")
        conv_code.nontype_targ("N", "size_t")
        conv_code.arg("arr", "std::array<T, N>")
        conv_code.raw(f"""
tv::array<T, N> tarr;
for (int i = 0; i < N; ++i){{
tarr[i] = arr[i];
}}
return tarr;
""")
        return conv_code.ret("tv::array<T, N>")

    @pccm.static_function
    def tvarray2array(self):
        """Emit a template copying a ``tv::array<T, N>`` into a ``std::array<T, N>``."""
        conv_code = pccm.FunctionCode()
        conv_code.targ("T")
        conv_code.nontype_targ("N", "size_t")
        conv_code.arg("arr", "tv::array<T, N>")
        conv_code.raw(f"""
std::array<T, N> tarr;
for (int i = 0; i < N; ++i){{
tarr[i] = arr[i];
}}
return tarr;
""")
        return conv_code.ret("std::array<T, N>")
class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
"""this class don't support multi-thread. """this class don't support multi-thread.
...@@ -145,16 +233,63 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -145,16 +233,63 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code return code
@pccm.cuda.cuda_global_function
def voxel_empty_fill_mean(self):
    """Emit the kernel that pads unused point slots with the voxel mean.

    For every voxel the stored point count is clamped to
    ``num_points_per_voxel`` and written back. Then, per feature, the mean
    of the first ``count`` points is computed (0 when ``count`` is 0) and
    written into the remaining slots ``count .. num_points_per_voxel - 1``.
    """
    kernel = pccm.FunctionCode()
    kernel.arg("voxels", f"{self.dtype} *")
    kernel.arg("num_per_voxel", "int *")
    kernel.arg("num_voxels", "int")
    kernel.arg("num_points_per_voxel", "int")
    kernel.arg("num_voxel_features", "int")
    # The C++ text is kept at column 0 so the emitted source is unchanged.
    kernel.raw(f"""
int voxel_stride = num_points_per_voxel * num_voxel_features;
for (int i : tv::KernelLoopX<int>(num_voxels)){{
int count = min(num_points_per_voxel, num_per_voxel[i]);
num_per_voxel[i] = count;
for (int j = 0; j < num_voxel_features; ++j){{
auto voxel_ptr = voxels + i * voxel_stride + j;
{self.dtype} sum_val = 0;
for (int k = 0; k < count; ++k){{
sum_val += voxel_ptr[0];
voxel_ptr += num_voxel_features;
}}
sum_val = count == 0 ? 0 : sum_val / count;
for (int k = count; k < num_points_per_voxel; ++k){{
voxel_ptr[0] = sum_val;
voxel_ptr += num_voxel_features;
}}
}}
}}
""")
    return kernel
@pccm.cuda.cuda_global_function
def limit_num_per_voxel_value(self):
    """Emit the kernel clamping each voxel's point count to ``num_points_per_voxel``."""
    kernel = pccm.FunctionCode()
    kernel.arg("num_per_voxel", "int *")
    kernel.arg("num_voxels, num_points_per_voxel", "int")
    # The C++ text is kept at column 0 so the emitted source is unchanged.
    kernel.raw(f"""
for (int i : tv::KernelLoopX<int>(num_voxels)){{
int count = min(num_points_per_voxel, num_per_voxel[i]);
num_per_voxel[i] = count;
}}
""")
    return kernel
class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True):
super().__init__() super().__init__()
self.add_dependency(TensorView) self.add_dependency(TensorView)
self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx)
self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon")
layout = TensorGeneric(ndim, True) layout = TensorGeneric(ndim, True)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
cuda_funcs = [self.point_to_voxel_hash] cuda_funcs = [self.point_to_voxel_hash, self.point_to_voxel_hash_static]
self.add_impl_only_param_class(cuda_funcs, "kernel", Point2VoxelKernel(dtype, ndim, layout, zyx)) self.add_impl_only_param_class(cuda_funcs, "kernel", Point2VoxelKernel(dtype, ndim, layout, zyx))
self.add_pybind_member("hashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") self.add_pybind_member("hashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor")
...@@ -230,8 +365,25 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -230,8 +365,25 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("points", "tv::Tensor") code.arg("points", "tv::Tensor")
code.arg("clear_voxels", "bool", "true") code.arg("clear_voxels", "bool", "true")
code.arg("empty_mean", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
int64_t expected_hash_data_num = points.dim(0) * 2;
if (hashdata.dim(0) < expected_hash_data_num){{
hashdata = tv::zeros({{expected_hash_data_num}}, tv::custom128, 0);
}}
if (point_indice_data.dim(0) < points.dim(0)){{
point_indice_data = tv::zeros({{points.dim(0)}}, tv::int64, 0);
}}
return point_to_voxel_hash_static(points, voxels, indices, num_per_voxel,
hashdata, point_indice_data, Point2VoxelCommon::tvarray2array(vsize),
Point2VoxelCommon::tvarray2array(grid_size), Point2VoxelCommon::tvarray2array(grid_stride),
Point2VoxelCommon::tvarray2array(coors_range), clear_voxels, empty_mean, stream_int);
""")
return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
code.raw(f"""
TV_ASSERT_INVALID_ARG(points.ndim() == 2 && points.dim(1) >= {self.ndim}, "error"); TV_ASSERT_INVALID_ARG(points.ndim() == 2 && points.dim(1) >= {self.ndim}, "error");
using V = int64_t; using V = int64_t;
using KeyType = int64_t; using KeyType = int64_t;
...@@ -288,6 +440,86 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -288,6 +440,86 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>") return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
@pccm.pybind.mark
@pccm.cuda.static_function
def point_to_voxel_hash_static(self):
    """Emit the static, caller-buffered GPU point-to-voxel conversion.

    All working tensors (``voxels``, ``indices``, ``num_per_voxel``,
    ``hashdata``, ``point_indice_data``) are supplied by the caller. The
    generated code asserts that ``hashdata`` holds at least
    ``2 * points.dim(0)`` entries and ``point_indice_data`` at least
    ``points.dim(0)``. It then builds a linear hash table of voxel keys,
    assigns voxel slots, copies the voxel count to the host, scatters
    points into voxels, and finally either fills empty slots with the mean
    (``empty_mean``) or clamps the per-voxel counts. The three output
    tensors are returned sliced to the voxel count.
    """
    int_dims = f"std::array<int, {self.ndim}>"
    host_code = pccm.FunctionCode()
    host_code.arg("points", "tv::Tensor")
    host_code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", "tv::Tensor")
    host_code.arg("vsize", f"std::array<float, {self.ndim}>")
    host_code.arg("grid_size, grid_stride", int_dims)
    host_code.arg("coors_range", f"std::array<float, {self.ndim * 2}>")
    host_code.arg("clear_voxels", "bool", "true")
    host_code.arg("empty_mean", "bool", "false")
    host_code.arg("stream_int", "std::uintptr_t", "0")
    # NOTE(review): the emitted code creates `voxel_launcher` (sized by the
    # voxel count) but launches the two per-voxel kernels with `launcher`
    # (sized by the point count); `voxel_launcher` is unused. Confirm intent.
    # The C++ text is kept at column 0 so the emitted source is unchanged.
    host_code.raw(f"""
auto vsize_tv = Point2VoxelCommon::array2tvarray(vsize);
auto grid_size_tv = Point2VoxelCommon::array2tvarray(grid_size);
auto grid_stride_tv = Point2VoxelCommon::array2tvarray(grid_stride);
auto coors_range_tv = Point2VoxelCommon::array2tvarray(coors_range);
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
auto ctx = tv::Context();
ctx.set_cuda_stream(custream);
TV_ASSERT_INVALID_ARG(points.ndim() == 2 && points.dim(1) >= {self.ndim}, "error");
using V = int64_t;
using KeyType = int64_t;
constexpr KeyType kEmptyKey = std::numeric_limits<KeyType>::max();
if (clear_voxels){{
voxels.zero_(ctx);
}}
using table_t =
tv::hash::LinearHashTable<KeyType, V, tv::hash::Murmur3Hash<KeyType>,
kEmptyKey, false>;
using pair_t = typename table_t::value_type;
// int64_t expected_hash_data_num = int64_t(tv::hash::align_to_power2(points.dim(0) * 2));
int64_t expected_hash_data_num = points.dim(0) * 2;
TV_ASSERT_RT_ERR(hashdata.dim(0) >= expected_hash_data_num, "hash table too small")
TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small")
// auto timer = tv::CudaContextTimer<>();
num_per_voxel.zero_(ctx);
table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num);
hash.clear(custream);
auto launcher = tv::cuda::Launch(points.dim(0), custream);
launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const {self.dtype}>(),
point_indice_data.data_ptr<int64_t>(),
points.dim(1), vsize_tv, coors_range_tv, grid_size_tv, grid_stride_tv, points.dim(0));
auto table_launcher = tv::cuda::Launch(hash.size(), custream);
tv::Tensor count = tv::zeros({{1}}, tv::int32, 0);
Layout layout = Layout::from_shape(grid_size_tv);
table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(),
count.data_ptr<int>(),
layout, voxels.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
// tv::ssprint("assign_table", timer.report());
launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const {self.dtype}>(),
point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<{self.dtype}>(),
num_per_voxel.data_ptr<int>(), points.dim(1), voxels.dim(1),
voxels.dim(0), vsize_tv, coors_range_tv,
grid_size_tv, grid_stride_tv, points.dim(0));
// tv::ssprint("generate_voxel", timer.report());
auto voxel_launcher = tv::cuda::Launch(count_val, custream);
if (empty_mean){{
launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<{self.dtype}>(),
num_per_voxel.data_ptr<int>(), count_val,
voxels.dim(1), voxels.dim(2));
}}else{{
launcher(kernel::limit_num_per_voxel_value, num_per_voxel.data_ptr<int>(), count_val,
voxels.dim(1));
}}
return std::make_tuple(voxels.slice_first_axis(0, count_val),
indices.slice_first_axis(0, count_val),
num_per_voxel.slice_first_axis(0, count_val));
""")
    return host_code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True):
...@@ -298,13 +530,14 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -298,13 +530,14 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx)
self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon")
self.add_pybind_member("densehashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") self.add_pybind_member("densehashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor")
self.add_pybind_member("voxels", "tv::Tensor", readwrite=False) self.add_pybind_member("voxels", "tv::Tensor", readwrite=False)
self.add_pybind_member("indices", "tv::Tensor", readwrite=False) self.add_pybind_member("indices", "tv::Tensor", readwrite=False)
self.add_pybind_member("num_per_voxel", "tv::Tensor", readwrite=False) self.add_pybind_member("num_per_voxel", "tv::Tensor", readwrite=False)
self.add_member("mean_per_voxel", "tv::Tensor")
self.add_member("vsize", f"tv::array<float, {self.ndim}>") self.add_member("vsize", f"tv::array<float, {self.ndim}>")
self.add_member("coors_range", f"tv::array<float, {self.ndim * 2}>") self.add_member("coors_range", f"tv::array<float, {self.ndim * 2}>")
...@@ -324,6 +557,18 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -324,6 +557,18 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code.ret(f"std::array<int, {self.ndim}>") return code.ret(f"std::array<int, {self.ndim}>")
@pccm.pybind.mark
@pccm.static_function
def calc_meta_data(self):
    """Emit a static wrapper forwarding to ``Point2VoxelCommon::calc_meta_data``.

    The emitted function takes the xyz voxel size and coordinate range and
    returns whatever the common helper returns; its C++ return type is
    taken from ``self.p2v_c.calc_meta_ret``.
    """
    ndim = self.ndim
    wrapper = pccm.FunctionCode()
    wrapper.arg("vsize_xyz", f"std::array<float, {ndim}>")
    wrapper.arg("coors_range_xyz", f"std::array<float, {ndim * 2}>")
    wrapper.raw(f"""
return Point2VoxelCommon::calc_meta_data(vsize_xyz, coors_range_xyz);
""")
    return wrapper.ret(self.p2v_c.calc_meta_ret)
@pccm.pybind.mark @pccm.pybind.mark
@pccm.constructor @pccm.constructor
def ctor(self): def ctor(self):
...@@ -361,7 +606,6 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -361,7 +606,6 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
voxels = tv::zeros({{max_num_voxels, max_num_points_per_voxel, num_point_features}}, tv::type_v<{self.dtype}>, -1); voxels = tv::zeros({{max_num_voxels, max_num_points_per_voxel, num_point_features}}, tv::type_v<{self.dtype}>, -1);
indices = tv::zeros({{max_num_voxels, {self.ndim}}}, tv::int32, -1); indices = tv::zeros({{max_num_voxels, {self.ndim}}}, tv::int32, -1);
num_per_voxel = tv::zeros({{max_num_voxels}}, tv::int32, -1); num_per_voxel = tv::zeros({{max_num_voxels}}, tv::int32, -1);
mean_per_voxel = tv::zeros({{max_num_voxels, num_point_features}}, tv::DType({self.dtype.tv_dtype}), -1);
tv::TensorShape grid_shape(grid_size.data(), grid_size.data() + {self.ndim}); tv::TensorShape grid_shape(grid_size.data(), grid_size.data() + {self.ndim});
densehashdata = tv::zeros(grid_shape, tv::int32, -1); densehashdata = tv::zeros(grid_shape, tv::int32, -1);
auto densehashdata_ptr = densehashdata.data_ptr<int>(); auto densehashdata_ptr = densehashdata.data_ptr<int>();
...@@ -371,9 +615,14 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -371,9 +615,14 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code return code
def point_to_voxel_template(self, mean: bool = False): def point_to_voxel_static_template(self, mean: bool = False):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("points", "tv::Tensor") code.arg("points", "tv::Tensor")
code.arg("voxels, indices, num_per_voxel, densehashdata", "tv::Tensor")
code.arg("vsize", f"std::array<float, {self.ndim}>")
code.arg("grid_size, grid_stride", f"std::array<int, {self.ndim}>")
code.arg("coors_range", f"std::array<float, {self.ndim * 2}>")
code.arg("clear_voxels", "bool", "true") code.arg("clear_voxels", "bool", "true")
point_xyz = f"{self.ndim - 1} - j" point_xyz = f"{self.ndim - 1} - j"
...@@ -386,14 +635,6 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -386,14 +635,6 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
if (clear_voxels){{ if (clear_voxels){{
voxels.zero_(); voxels.zero_();
}} }}
""")
if mean:
code.raw(f"mean_per_voxel.zero_();")
code.raw(f"auto means_rw = mean_per_voxel.tview<{self.dtype}, 2>();")
else:
code.raw(f"auto means_rw = mean_per_voxel.tview<{self.dtype}, 2>();")
code.raw(f"""
int res_voxel_num = 0; int res_voxel_num = 0;
int num_features = points.dim(1); int num_features = points.dim(1);
auto N = points.dim(0); auto N = points.dim(0);
...@@ -442,21 +683,27 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -442,21 +683,27 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
voxels_rw(voxelidx, num, k) = points_rw(i, k); voxels_rw(voxelidx, num, k) = points_rw(i, k);
}} }}
num_points_per_voxel_rw(voxelidx) += 1; num_points_per_voxel_rw(voxelidx) += 1;
if TV_IF_CONSTEXPR ({pccm.boolean(mean)}){{
for (int k = 0; k < num_features; ++k) {{
means_rw(voxelidx, k) +=
(points_rw(i, k) - means_rw(voxelidx, k)) / {self.dtype}(num + 1);
}}
}}
}} }}
}} }}
std::vector<{self.dtype}> mean_value(num_features);
for (int i = 0; i < voxel_num; ++i) {{ for (int i = 0; i < voxel_num; ++i) {{
coor_to_voxelidx_rw({codeops.unpack("coors_rw", range(self.ndim), left="(i, ", right=")")}) = -1; coor_to_voxelidx_rw({codeops.unpack("coors_rw", range(self.ndim), left="(i, ", right=")")}) = -1;
if TV_IF_CONSTEXPR ({pccm.boolean(mean)}){{ if TV_IF_CONSTEXPR ({pccm.boolean(mean)}){{
num = num_points_per_voxel_rw(i); num = num_points_per_voxel_rw(i);
for (int j = num; j < max_num_points_per_voxel; ++j) {{ if (num > 0){{
for (int k = 0; k < num_features; ++k) {{ mean_value.clear();
voxels_rw(i, j, k) = means_rw(i, k); for (int j = 0; j < num; ++j) {{
for (int k = 0; k < num_features; ++k) {{
mean_value[k] += voxels_rw(i, j, k);
}}
}}
for (int k = 0; k < num_features; ++k){{
mean_value[k] /= num;
}}
for (int j = num; j < max_num_points_per_voxel; ++j) {{
for (int k = 0; k < num_features; ++k) {{
voxels_rw(i, j, k) = mean_value[k];
}}
}} }}
}} }}
}} }}
...@@ -469,13 +716,70 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -469,13 +716,70 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>") return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
@pccm.static_function
def array2tvarray(self):
    """Generate a static template helper: ``std::array<T, N>`` -> ``tv::array<T, N>``.

    The emitted C++ is templated on the element type ``T`` and the length
    ``N``; it copies the input element by element into a fresh
    ``tv::array`` and returns it by value.
    """
    fn = pccm.FunctionCode()
    fn.targ("T")
    fn.nontype_targ("N", "size_t")
    fn.arg("arr", "std::array<T, N>")
    # No interpolation is needed here, so a plain string with single braces
    # yields the same C++ text as an f-string with escaped braces.
    fn.raw("""
    tv::array<T, N> tarr;
    for (int i = 0; i < N; ++i){
        tarr[i] = arr[i];
    }
    return tarr;
    """)
    return fn.ret("tv::array<T, N>")
@pccm.static_function
def tvarray2array(self):
    """Generate a static template helper: ``tv::array<T, N>`` -> ``std::array<T, N>``.

    The emitted C++ is templated on the element type ``T`` and the length
    ``N``; it copies the input element by element into a fresh
    ``std::array`` and returns it by value.
    """
    fn = pccm.FunctionCode()
    fn.targ("T")
    fn.nontype_targ("N", "size_t")
    fn.arg("arr", "tv::array<T, N>")
    # No interpolation is needed here, so a plain string with single braces
    # yields the same C++ text as an f-string with escaped braces.
    fn.raw("""
    std::array<T, N> tarr;
    for (int i = 0; i < N; ++i){
        tarr[i] = arr[i];
    }
    return tarr;
    """)
    return fn.ret("std::array<T, N>")
@pccm.pybind.mark
@pccm.static_function
def point_to_voxel_static(self):
    """Generate the static, pybind-exposed point-to-voxel function.

    Delegates to ``point_to_voxel_static_template`` with ``mean`` disabled.
    """
    return self.point_to_voxel_static_template(mean=False)
@pccm.pybind.mark
@pccm.static_function
def point_to_voxel_empty_mean_static(self):
    """Generate the static, pybind-exposed "empty mean" point-to-voxel function.

    Delegates to ``point_to_voxel_static_template`` with ``mean`` enabled.
    """
    return self.point_to_voxel_static_template(mean=True)
@pccm.pybind.mark @pccm.pybind.mark
@pccm.member_function @pccm.member_function
def point_to_voxel(self): def point_to_voxel(self):
return self.point_to_voxel_template(False) code = pccm.FunctionCode()
code.arg("points", "tv::Tensor")
code.arg("clear_voxels", "bool", "true")
code.raw(f"""
return point_to_voxel_static(points, voxels, indices, num_per_voxel, densehashdata,
tvarray2array(vsize),
tvarray2array(grid_size), tvarray2array(grid_stride),
tvarray2array(coors_range), clear_voxels);
""")
return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.member_function @pccm.member_function
def point_to_voxel_empty_mean(self): def point_to_voxel_empty_mean(self):
return self.point_to_voxel_template(True) code = pccm.FunctionCode()
code.arg("points", "tv::Tensor")
code.arg("clear_voxels", "bool", "true")
code.raw(f"""
return point_to_voxel_empty_mean_static(points, voxels, indices, num_per_voxel,
densehashdata, tvarray2array(vsize),
tvarray2array(grid_size), tvarray2array(grid_stride),
tvarray2array(coors_range), clear_voxels);
""")
return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -23,14 +23,15 @@ from torch.nn import init ...@@ -23,14 +23,15 @@ from torch.nn import init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from spconv import pytorch as spconv from spconv import pytorch as spconv
from spconv.algo import ConvAlgo from spconv.core import ConvAlgo
import spconv.pytorch.functional as Fsp import spconv.pytorch.functional as Fsp
from spconv.pytorch import ops from spconv.pytorch import ops
from spconv.pytorch.core import IndiceData, SparseConvTensor from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO
def _calculate_fan_in_and_fan_out_hwio(tensor):
def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo):
dimensions = tensor.ndimension() dimensions = tensor.ndimension()
if dimensions < 2: if dimensions < 2:
raise ValueError( raise ValueError(
...@@ -41,15 +42,24 @@ def _calculate_fan_in_and_fan_out_hwio(tensor): ...@@ -41,15 +42,24 @@ def _calculate_fan_in_and_fan_out_hwio(tensor):
fan_in = tensor.size(-2) fan_in = tensor.size(-2)
fan_out = tensor.size(-1) fan_out = tensor.size(-1)
else: else:
if FILTER_HWIO: if algo == ConvAlgo.Native:
num_input_fmaps = tensor.size(-2) if FILTER_HWIO:
num_output_fmaps = tensor.size(-1) num_input_fmaps = tensor.size(-2)
num_output_fmaps = tensor.size(-1)
else:
num_input_fmaps = tensor.size(-1)
num_output_fmaps = tensor.size(-2)
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[..., 0, 0].numel()
else: else:
num_input_fmaps = tensor.size(-1) num_input_fmaps = tensor.size(-1)
num_output_fmaps = tensor.size(-2) num_output_fmaps = tensor.size(0)
receptive_field_size = 1 receptive_field_size = 1
if tensor.dim() > 2: if tensor.dim() > 2:
receptive_field_size = tensor[..., 0, 0].numel() receptive_field_size = int(np.prod(tensor.shape[1:-1]))
fan_in = num_input_fmaps * receptive_field_size fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size fan_out = num_output_fmaps * receptive_field_size
...@@ -59,29 +69,28 @@ def _calculate_fan_in_and_fan_out_hwio(tensor): ...@@ -59,29 +69,28 @@ def _calculate_fan_in_and_fan_out_hwio(tensor):
class SparseConvolution(SparseModule): class SparseConvolution(SparseModule):
__constants__ = [ __constants__ = [
'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse', 'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
'transposed', 'output_padding', 'fused_bn' 'transposed', 'output_padding'
] ]
def __init__(self, def __init__(self,
ndim: int, ndim: int,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
kernel_size: Union[int, List[int], Tuple[int, ...]]=3, kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
stride: Union[int, List[int], Tuple[int, ...]]=1, stride: Union[int, List[int], Tuple[int, ...]] = 1,
padding: Union[int, List[int], Tuple[int, ...]]=0, padding: Union[int, List[int], Tuple[int, ...]] = 0,
dilation: Union[int, List[int], Tuple[int, ...]]=1, dilation: Union[int, List[int], Tuple[int, ...]] = 1,
groups: Union[int, List[int], Tuple[int, ...]]=1, groups: Union[int, List[int], Tuple[int, ...]] = 1,
bias: bool=True, bias: bool = True,
subm: bool=False, subm: bool = False,
output_padding: Union[int, List[int], Tuple[int, ...]]=0, output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
transposed: bool=False, transposed: bool = False,
inverse: bool=False, inverse: bool = False,
indice_key: Optional[str]=None, indice_key: Optional[str] = None,
fused_bn: bool=False, algo: Optional[ConvAlgo] = None,
algo: ops.ConvAlgo=ops.ConvAlgo.Native,
name=None): name=None):
super(SparseConvolution, self).__init__(name=name) super(SparseConvolution, self).__init__(name=name)
assert groups == 1 assert groups == 1, "don't support groups for now"
if not isinstance(kernel_size, (list, tuple)): if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * ndim kernel_size = [kernel_size] * ndim
if not isinstance(stride, (list, tuple)): if not isinstance(stride, (list, tuple)):
...@@ -96,7 +105,8 @@ class SparseConvolution(SparseModule): ...@@ -96,7 +105,8 @@ class SparseConvolution(SparseModule):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.conv1x1 = np.prod(kernel_size) == 1 kv = int(np.prod(kernel_size))
self.conv1x1 = kv == 1
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
self.dilation = dilation self.dilation = dilation
...@@ -106,31 +116,77 @@ class SparseConvolution(SparseModule): ...@@ -106,31 +116,77 @@ class SparseConvolution(SparseModule):
self.groups = groups self.groups = groups
self.subm = subm self.subm = subm
self.indice_key = indice_key self.indice_key = indice_key
self.fused_bn = fused_bn if algo is None:
if kv <= 32:
if kv < 8:
algo = ConvAlgo.MaskImplicitGemm
else:
algo = ConvAlgo.MaskImplicitGemm
else:
algo = ConvAlgo.Native
if kv > 32:
assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now"
self.algo = algo self.algo = algo
if FILTER_HWIO: # self.algo = ConvAlgo.Native
self.weight = Parameter( if self.algo == ConvAlgo.Native:
torch.Tensor(*kernel_size, in_channels, out_channels)) if FILTER_HWIO:
# RSCK
self.weight = Parameter(
torch.Tensor(*kernel_size, in_channels, out_channels))
else:
# RSKC
self.weight = Parameter(
torch.Tensor(*kernel_size, out_channels, in_channels))
else: else:
# KRSC
self.weight = Parameter( self.weight = Parameter(
torch.Tensor(*kernel_size, out_channels, in_channels)) torch.Tensor(out_channels, *kernel_size, in_channels))
if bias: if bias:
self.bias = Parameter(torch.Tensor(out_channels)) self.bias = Parameter(torch.Tensor(out_channels))
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()
def extra_repr(self):
    """Return the hyper-parameter summary used in the module's repr.

    ``in_channels``, ``out_channels``, ``kernel_size`` and ``stride`` are
    always shown. ``padding``, ``dilation``, ``output_padding`` and
    ``groups`` are shown only when they differ from their defaults
    (0, 1, 0 and 1 respectively), ``bias=False`` is shown when the layer
    has no bias, and the algorithm is appended when it is set.

    Returns:
        The formatted summary string.
    """
    def _differs(values, default):
        # These attributes may be stored as lists, and ``list != tuple`` is
        # always True in Python, so comparing against ``(default,) * n``
        # would report default values as non-default. Compare element-wise.
        return any(v != default for v in values)

    s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
         ', stride={stride}')
    if _differs(self.padding, 0):
        s += ', padding={padding}'
    if _differs(self.dilation, 1):
        s += ', dilation={dilation}'
    if _differs(self.output_padding, 0):
        s += ', output_padding={output_padding}'
    if self.groups != 1:
        s += ', groups={groups}'
    if self.bias is None:
        s += ', bias=False'
    summary = s.format(**self.__dict__)
    # Append the algo text after formatting so that it is never interpreted
    # as part of the format template.
    if self.algo is not None:
        summary += f', algo={self.algo}'
    return summary
def reset_parameters(self): def reset_parameters(self):
n = self.in_channels n = self.in_channels
# init.uniform_(self.weight, 0, 0.001) # following commented code is used to make weight different layout have same value
# if self.algo != ConvAlgo.Native:
# weight2 = self.weight.data.permute(1, 2, 3, 0,
# 4).contiguous().clone()
# init.uniform_(weight2, 0, 0.001)
# self.weight.data[:] = weight2.permute(3, 0, 1, 2, 4)
# else:
# init.uniform_(self.weight, 0, 0.001)
init.kaiming_uniform_(self.weight, a=math.sqrt(0.005)) init.kaiming_uniform_(self.weight, a=math.sqrt(0.005))
if self.bias is not None: if self.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out_hwio(self.weight) fan_in, _ = _calculate_fan_in_and_fan_out_hwio(
self.weight, self.algo)
bound = 1 / math.sqrt(fan_in) bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound) init.uniform_(self.bias, -bound, bound)
def forward(self, input: SparseConvTensor): def forward(self, input: SparseConvTensor):
assert isinstance(input, SparseConvTensor) assert isinstance(input, SparseConvTensor)
assert input.features.shape[
1] == self.in_channels, "channel size mismatch"
features = input.features features = input.features
device = features.device device = features.device
indices = input.indices indices = input.indices
...@@ -188,79 +244,161 @@ class SparseConvolution(SparseModule): ...@@ -188,79 +244,161 @@ class SparseConvolution(SparseModule):
features += self.bias features += self.bias
out_tensor = out_tensor.replace_feature(features) out_tensor = out_tensor.replace_feature(features)
return out_tensor return out_tensor
datas = input.find_indice_pair(self.indice_key) indice_dict = input.indice_dict.copy()
if self.inverse:
assert datas is not None and self.indice_key is not None algo = self.algo
assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops." if self.indice_key is not None :
outids = datas.indices datas = input.find_indice_pair(self.indice_key)
indice_pairs = datas.indice_pairs if datas is not None:
indice_pair_num = datas.indice_pair_num msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
out_spatial_shape = datas.out_spatial_shape assert algo == datas.algo, msg
assert indice_pair_num.shape[0] == np.prod( # algo = datas.algo
self.kernel_size if algo == ConvAlgo.Native:
), "inverse conv must have same kernel size as its couple conv" datas = input.find_indice_pair(self.indice_key)
else: if datas is not None:
if self.indice_key is not None and datas is not None: assert isinstance(datas, IndiceData)
outids = datas.out_indices if self.inverse:
assert datas is not None and self.indice_key is not None
assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
outids = datas.indices
indice_pairs = datas.indice_pairs indice_pairs = datas.indice_pairs
indice_pair_num = datas.indice_pair_num indice_pair_num = datas.indice_pair_num
out_spatial_shape = datas.out_spatial_shape
assert indice_pair_num.shape[0] == np.prod(
self.kernel_size
), "inverse conv must have same kernel size as its couple conv"
else: else:
if input.benchmark: if self.indice_key is not None and datas is not None:
torch.cuda.synchronize() outids = datas.out_indices
t = time.time() indice_pairs = datas.indice_pairs
outids, indice_pairs, indice_pair_num = ops.get_indice_pairs( indice_pair_num = datas.indice_pair_num
indices, else:
batch_size, if input.benchmark:
spatial_shape, torch.cuda.synchronize()
self.algo, t = time.time()
self.kernel_size, outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
self.stride, indices, batch_size, spatial_shape, algo,
self.padding, self.kernel_size, self.stride, self.padding,
self.dilation, self.dilation, self.output_padding, self.subm,
self.output_padding, self.transposed)
self.subm, if input.benchmark:
self.transposed) torch.cuda.synchronize()
if input.benchmark: interval = time.time() - t
torch.cuda.synchronize() out_tensor.benchmark_record[
interval = time.time() - t self.name]["indice_gen_time"].append(interval)
out_tensor.benchmark_record[
self.name]["indice_gen_time"].append(interval) indice_data = IndiceData(outids,
indices,
indice_data = IndiceData(outids, indices, indice_pairs, indice_pairs,
indice_pair_num, spatial_shape, is_subm=self.subm) indice_pair_num,
input.indice_dict[self.indice_key] = indice_data spatial_shape,
if input.benchmark: is_subm=self.subm,
torch.cuda.synchronize() algo=algo)
t = time.time() if self.indice_key is not None:
msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
if self.fused_bn: assert self.indice_key not in indice_dict, msg
raise NotImplementedError indice_dict[self.indice_key] = indice_data
assert self.bias is not None if input.benchmark:
out_features = ops.fused_indice_conv(features, self.weight, torch.cuda.synchronize()
self.bias, t = time.time()
indice_pairs.to(device), indice_pairs_calc = indice_pairs
indice_pair_num, if indice_pairs.device != features.device:
outids.shape[0], self.inverse, indice_pairs_calc = indice_pairs.to(features.device)
self.subm)
else:
if self.subm: if self.subm:
out_features = Fsp.indice_subm_conv(features, self.weight, out_features = Fsp.indice_subm_conv(features, self.weight,
indice_pairs.to(device), indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], self.algo) outids.shape[0], algo)
else: else:
if self.inverse: if self.inverse:
out_features = Fsp.indice_inverse_conv( out_features = Fsp.indice_inverse_conv(
features, self.weight, indice_pairs.to(device), features, self.weight, indice_pairs_calc,
indice_pair_num, outids.shape[0], self.algo) indice_pair_num, outids.shape[0], algo)
else: else:
out_features = Fsp.indice_conv(features, self.weight, out_features = Fsp.indice_conv(features, self.weight,
indice_pairs.to(device), indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], self.algo) outids.shape[0], algo)
if self.bias is not None: else:
out_features += self.bias datas = input.find_indice_pair(self.indice_key)
if datas is not None:
assert isinstance(datas, ImplicitGemmIndiceData)
if self.inverse:
assert datas is not None and self.indice_key is not None
assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
outids = datas.indices
pair_fwd = datas.pair_bwd
pair_bwd = datas.pair_fwd
pair_mask_fwd_splits = datas.pair_mask_bwd_splits
pair_mask_bwd_splits = datas.pair_mask_fwd_splits
mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
masks = datas.masks
else:
if self.indice_key is not None and datas is not None:
outids = datas.out_indices
pair_fwd = datas.pair_fwd
pair_bwd = datas.pair_bwd
pair_mask_fwd_splits = datas.pair_mask_fwd_splits
pair_mask_bwd_splits = datas.pair_mask_bwd_splits
mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks
else:
res = ops.get_indice_pairs_implicit_gemm(
indices,
batch_size,
spatial_shape,
algo,
ksize=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
out_padding=self.output_padding,
subm=self.subm,
transpose=self.transposed,
is_train=self.training,
alloc=input.thrust_allocator)
outids = res[0]
num_inds_per_loc = res[1]
pair_fwd = res[2]
pair_bwd = res[3]
pair_mask_fwd_splits = res[4]
pair_mask_bwd_splits = res[5]
mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7]
masks = res[8]
if self.indice_key is not None:
indice_data = ImplicitGemmIndiceData(
outids,
indices,
pair_fwd,
pair_bwd,
pair_mask_fwd_splits=pair_mask_fwd_splits,
pair_mask_bwd_splits=pair_mask_bwd_splits,
mask_argsort_fwd_splits=mask_argsort_fwd_splits,
mask_argsort_bwd_splits=mask_argsort_bwd_splits,
masks=masks,
is_subm=self.subm,
out_spatial_shape=out_spatial_shape,
algo=algo)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data
if input.benchmark:
torch.cuda.synchronize()
t = time.time()
num_activate_out = outids.shape[0]
out_features = Fsp.implicit_gemm(
features, self.weight, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm)
if self.bias is not None:
out_features += self.bias
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
interval = time.time() - t interval = time.time() - t
...@@ -271,9 +409,11 @@ class SparseConvolution(SparseModule): ...@@ -271,9 +409,11 @@ class SparseConvolution(SparseModule):
out_features.shape[0]) out_features.shape[0])
out_tensor = out_tensor.replace_feature(out_features) out_tensor = out_tensor.replace_feature(out_features)
out_tensor.indices = outids out_tensor.indices = outids
out_tensor.indice_dict = indice_dict
out_tensor.spatial_shape = out_spatial_shape out_tensor.spatial_shape = out_spatial_shape
return out_tensor return out_tensor
class SparseConv1d(SparseConvolution): class SparseConv1d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
...@@ -285,7 +425,7 @@ class SparseConv1d(SparseConvolution): ...@@ -285,7 +425,7 @@ class SparseConv1d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseConv1d, self).__init__(1, super(SparseConv1d, self).__init__(1,
in_channels, in_channels,
...@@ -312,7 +452,7 @@ class SparseConv2d(SparseConvolution): ...@@ -312,7 +452,7 @@ class SparseConv2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseConv2d, self).__init__(2, super(SparseConv2d, self).__init__(2,
in_channels, in_channels,
...@@ -339,7 +479,7 @@ class SparseConv3d(SparseConvolution): ...@@ -339,7 +479,7 @@ class SparseConv3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseConv3d, self).__init__(3, super(SparseConv3d, self).__init__(3,
in_channels, in_channels,
...@@ -366,7 +506,7 @@ class SparseConv4d(SparseConvolution): ...@@ -366,7 +506,7 @@ class SparseConv4d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseConv4d, self).__init__(4, super(SparseConv4d, self).__init__(4,
in_channels, in_channels,
...@@ -393,7 +533,7 @@ class SparseConvTranspose1d(SparseConvolution): ...@@ -393,7 +533,7 @@ class SparseConvTranspose1d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseConvTranspose1d, self).__init__(1, super(SparseConvTranspose1d, self).__init__(1,
in_channels, in_channels,
...@@ -421,7 +561,7 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -421,7 +561,7 @@ class SparseConvTranspose2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseConvTranspose2d, self).__init__(2, super(SparseConvTranspose2d, self).__init__(2,
in_channels, in_channels,
...@@ -449,7 +589,7 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -449,7 +589,7 @@ class SparseConvTranspose3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseConvTranspose3d, self).__init__(3, super(SparseConvTranspose3d, self).__init__(3,
in_channels, in_channels,
...@@ -465,6 +605,7 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -465,6 +605,7 @@ class SparseConvTranspose3d(SparseConvolution):
algo=algo, algo=algo,
name=name) name=name)
class SparseConvTranspose4d(SparseConvolution): class SparseConvTranspose4d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
...@@ -476,7 +617,7 @@ class SparseConvTranspose4d(SparseConvolution): ...@@ -476,7 +617,7 @@ class SparseConvTranspose4d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseConvTranspose4d, self).__init__(4, super(SparseConvTranspose4d, self).__init__(4,
in_channels, in_channels,
...@@ -500,7 +641,7 @@ class SparseInverseConv1d(SparseConvolution): ...@@ -500,7 +641,7 @@ class SparseInverseConv1d(SparseConvolution):
kernel_size, kernel_size,
indice_key, indice_key,
bias=True, bias=True,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseInverseConv1d, self).__init__(1, super(SparseInverseConv1d, self).__init__(1,
in_channels, in_channels,
...@@ -520,7 +661,7 @@ class SparseInverseConv2d(SparseConvolution): ...@@ -520,7 +661,7 @@ class SparseInverseConv2d(SparseConvolution):
kernel_size, kernel_size,
indice_key, indice_key,
bias=True, bias=True,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseInverseConv2d, self).__init__(2, super(SparseInverseConv2d, self).__init__(2,
in_channels, in_channels,
...@@ -540,7 +681,7 @@ class SparseInverseConv3d(SparseConvolution): ...@@ -540,7 +681,7 @@ class SparseInverseConv3d(SparseConvolution):
kernel_size, kernel_size,
indice_key, indice_key,
bias=True, bias=True,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseInverseConv3d, self).__init__(3, super(SparseInverseConv3d, self).__init__(3,
in_channels, in_channels,
...@@ -552,6 +693,7 @@ class SparseInverseConv3d(SparseConvolution): ...@@ -552,6 +693,7 @@ class SparseInverseConv3d(SparseConvolution):
algo=algo, algo=algo,
name=name) name=name)
class SparseInverseConv4d(SparseConvolution): class SparseInverseConv4d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
...@@ -559,7 +701,7 @@ class SparseInverseConv4d(SparseConvolution): ...@@ -559,7 +701,7 @@ class SparseInverseConv4d(SparseConvolution):
kernel_size, kernel_size,
indice_key, indice_key,
bias=True, bias=True,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseInverseConv4d, self).__init__(4, super(SparseInverseConv4d, self).__init__(4,
in_channels, in_channels,
...@@ -571,6 +713,7 @@ class SparseInverseConv4d(SparseConvolution): ...@@ -571,6 +713,7 @@ class SparseInverseConv4d(SparseConvolution):
algo=algo, algo=algo,
name=name) name=name)
class SubMConv1d(SparseConvolution): class SubMConv1d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
...@@ -582,7 +725,7 @@ class SubMConv1d(SparseConvolution): ...@@ -582,7 +725,7 @@ class SubMConv1d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SubMConv1d, self).__init__(1, super(SubMConv1d, self).__init__(1,
in_channels, in_channels,
...@@ -610,7 +753,7 @@ class SubMConv2d(SparseConvolution): ...@@ -610,7 +753,7 @@ class SubMConv2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SubMConv2d, self).__init__(2, super(SubMConv2d, self).__init__(2,
in_channels, in_channels,
...@@ -638,7 +781,7 @@ class SubMConv3d(SparseConvolution): ...@@ -638,7 +781,7 @@ class SubMConv3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SubMConv3d, self).__init__(3, super(SubMConv3d, self).__init__(3,
in_channels, in_channels,
...@@ -666,7 +809,7 @@ class SubMConv4d(SparseConvolution): ...@@ -666,7 +809,7 @@ class SubMConv4d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None, indice_key=None,
algo=ops.ConvAlgo.Native, algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SubMConv4d, self).__init__(4, super(SubMConv4d, self).__init__(4,
in_channels, in_channels,
......
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from spconv.core import ConvAlgo
from spconv.pytorch.constants import PYTORCH_VERSION from spconv.pytorch.constants import PYTORCH_VERSION
from spconv.pytorch.ops import ThrustSortAllocator
if PYTORCH_VERSION >= [1, 8, 0]: if PYTORCH_VERSION >= [1, 8, 0]:
try: try:
...@@ -27,21 +29,48 @@ if PYTORCH_VERSION >= [1, 8, 0]: ...@@ -27,21 +29,48 @@ if PYTORCH_VERSION >= [1, 8, 0]:
from torch.fx.symbolic_trace import ProxyableClassMeta from torch.fx.symbolic_trace import ProxyableClassMeta
SpConvTensorMeta = ProxyableClassMeta SpConvTensorMeta = ProxyableClassMeta
except: except:
class SpConvTensorMeta(type): class SpConvTensorMeta(type):
pass pass
else: else:
class SpConvTensorMeta(type): class SpConvTensorMeta(type):
pass pass
class IndiceData(object): class IndiceData(object):
def __init__(self, out_indices, indices, indice_pairs, indice_pair_num, def __init__(self, out_indices, indices, indice_pairs, indice_pair_num,
out_spatial_shape, is_subm: bool): out_spatial_shape, is_subm: bool, algo: ConvAlgo):
self.out_indices = out_indices self.out_indices = out_indices
self.indices = indices self.indices = indices
self.indice_pairs = indice_pairs self.indice_pairs = indice_pairs
self.indice_pair_num = indice_pair_num self.indice_pair_num = indice_pair_num
self.out_spatial_shape = out_spatial_shape self.out_spatial_shape = out_spatial_shape
self.is_subm = is_subm self.is_subm = is_subm
self.algo = algo
class ImplicitGemmIndiceData(object):
    """Record bundling the index data used by the implicit-gemm convolution path.

    Every constructor argument is stored unchanged on the instance under an
    attribute of the same name; no validation or copying is performed.

    Args:
        out_indices: indices of the output tensor.
        indices: indices of the input tensor.
        pair_fwd: forward pair tensor.
        pair_bwd: backward pair tensor.
        pair_mask_fwd_splits: forward pair masks, one tensor per mask split.
        pair_mask_bwd_splits: backward pair masks, one tensor per mask split.
        mask_argsort_fwd_splits: argsort of the forward masks, one per split.
        mask_argsort_bwd_splits: argsort of the backward masks, one per split.
        masks: one numpy mask array per split.
        out_spatial_shape: spatial shape of the output.
        is_subm: whether the data was built for a submanifold convolution.
        algo: convolution algorithm the data was built for.

    NOTE(review): the ``*_bwd`` members are presumably empty when the data is
    built for inference -- confirm against the code that constructs this
    object.
    """
    def __init__(self, out_indices: torch.Tensor, indices: torch.Tensor,
                 pair_fwd: torch.Tensor, pair_bwd: torch.Tensor,
                 pair_mask_fwd_splits: List[torch.Tensor],
                 pair_mask_bwd_splits: List[torch.Tensor],
                 mask_argsort_fwd_splits: List[torch.Tensor],
                 mask_argsort_bwd_splits: List[torch.Tensor],
                 masks: List[np.ndarray], out_spatial_shape, is_subm: bool,
                 algo: ConvAlgo):
        # Configuration this record was built for.
        self.algo = algo
        self.is_subm = is_subm
        self.out_spatial_shape = out_spatial_shape

        # Input / output coordinates.
        self.indices = indices
        self.out_indices = out_indices

        # Forward-direction data.
        self.pair_fwd = pair_fwd
        self.pair_mask_fwd_splits = pair_mask_fwd_splits
        self.mask_argsort_fwd_splits = mask_argsort_fwd_splits

        # Backward-direction data.
        self.pair_bwd = pair_bwd
        self.pair_mask_bwd_splits = pair_mask_bwd_splits
        self.mask_argsort_bwd_splits = mask_argsort_bwd_splits

        self.masks = masks
def scatter_nd(indices, updates, shape): def scatter_nd(indices, updates, shape):
"""pytorch edition of tensorflow scatter_nd. """pytorch edition of tensorflow scatter_nd.
...@@ -58,6 +87,7 @@ def scatter_nd(indices, updates, shape): ...@@ -58,6 +87,7 @@ def scatter_nd(indices, updates, shape):
ret[slices] = updates.view(*output_shape) ret[slices] = updates.view(*output_shape)
return ret return ret
# ProxyableClassMeta is used for TensorRT conversion in future. # ProxyableClassMeta is used for TensorRT conversion in future.
class SparseConvTensor(metaclass=SpConvTensorMeta): class SparseConvTensor(metaclass=SpConvTensorMeta):
def __init__(self, def __init__(self,
...@@ -65,10 +95,11 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -65,10 +95,11 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
indices: torch.Tensor, indices: torch.Tensor,
spatial_shape: List[int], spatial_shape: List[int],
batch_size: int, batch_size: int,
grid: Optional[torch.Tensor]=None, grid: Optional[torch.Tensor] = None,
voxel_num: Optional[torch.Tensor]=None, voxel_num: Optional[torch.Tensor] = None,
indice_dict: Optional[dict] = None, indice_dict: Optional[dict] = None,
benchmark: bool=False): benchmark: bool = False,
permanent_thrust_allocator: bool = False):
""" """
Args: Args:
features: [num_points, num_features] feature tensor features: [num_points, num_features] feature tensor
...@@ -80,6 +111,12 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -80,6 +111,12 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
benchmark: whether to enable benchmark. if enabled, all sparse operators will be record to benchmark: whether to enable benchmark. if enabled, all sparse operators will be record to
SparseConvTensor. SparseConvTensor.
""" """
ndim = indices.shape[1] - 1
assert features.ndim == 2
assert indices.ndim == 2
assert len(spatial_shape) == ndim, "spatial shape must equal to ndim"
assert indices.dtype == torch.int32, "only support int32"
assert batch_size > 0
self._features = features self._features = features
self.indices = indices self.indices = indices
self.spatial_shape = spatial_shape self.spatial_shape = spatial_shape
...@@ -90,17 +127,24 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -90,17 +127,24 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
if grid is None: if grid is None:
grid = torch.Tensor() # empty tensor grid = torch.Tensor() # empty tensor
self.grid = grid self.grid = grid
self.voxel_num = voxel_num # for tensorrt self.voxel_num = voxel_num # for tensorrt
self.benchmark = benchmark self.benchmark = benchmark
self.benchmark_record = {} self.benchmark_record = {}
self.thrust_allocator: Optional[ThrustSortAllocator] = None
if permanent_thrust_allocator:
self.thrust_allocator = ThrustSortAllocator(features.device)
def replace_feature(self, feature): def replace_feature(self, feature):
"""we need to replace x.features = F.relu(x) with x = x.replace_feature(F.relu(x.features)) """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
due to limit of torch.fx due to limit of torch.fx
""" """
new_spt = SparseConvTensor(feature, self.indices, self.spatial_shape, self.batch_size, self.grid, self.voxel_num, self.indice_dict) new_spt = SparseConvTensor(feature, self.indices, self.spatial_shape,
self.batch_size, self.grid, self.voxel_num,
self.indice_dict)
new_spt.benchmark = self.benchmark new_spt.benchmark = self.benchmark
new_spt.benchmark_record = self.benchmark_record new_spt.benchmark_record = self.benchmark_record
new_spt.thrust_allocator = self.thrust_allocator
return new_spt return new_spt
@property @property
...@@ -109,8 +153,9 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -109,8 +153,9 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
@features.setter @features.setter
def features(self, val): def features(self, val):
msg = ("you can't set feature directly, use 'x = x.replace_feature(your_new_feature)'" msg = (
" to generate new SparseConvTensor instead.") "you can't set feature directly, use 'x = x.replace_feature(your_new_feature)'"
" to generate new SparseConvTensor instead.")
raise ValueError(msg) raise ValueError(msg)
@classmethod @classmethod
...@@ -129,14 +174,14 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -129,14 +174,14 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
def spatial_size(self): def spatial_size(self):
return np.prod(self.spatial_shape) return np.prod(self.spatial_shape)
def find_indice_pair(self, key) -> Optional[IndiceData]: def find_indice_pair(self, key) -> Optional[Union[IndiceData, ImplicitGemmIndiceData]]:
if key is None: if key is None:
return None return None
if key in self.indice_dict: if key in self.indice_dict:
return self.indice_dict[key] return self.indice_dict[key]
return None return None
def dense(self, channels_first: bool=True): def dense(self, channels_first: bool = True):
output_shape = [self.batch_size] + list( output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]] self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd( res = scatter_nd(
...@@ -159,6 +204,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -159,6 +204,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
"""create a new spconv tensor with all member unchanged""" """create a new spconv tensor with all member unchanged"""
tensor = SparseConvTensor(self.features, self.indices, tensor = SparseConvTensor(self.features, self.indices,
self.spatial_shape, self.batch_size, self.spatial_shape, self.batch_size,
self.grid, self.voxel_num, self.indice_dict, self.benchmark) self.grid, self.voxel_num, self.indice_dict,
self.benchmark)
tensor.benchmark_record = self.benchmark_record tensor.benchmark_record = self.benchmark_record
tensor.thrust_allocator = self.thrust_allocator
return tensor return tensor
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -17,10 +17,16 @@ from torch import nn ...@@ -17,10 +17,16 @@ from torch import nn
from torch.autograd import Function from torch.autograd import Function
import spconv.pytorch.ops as ops import spconv.pytorch.ops as ops
import torch.cuda.amp as amp
from torch.autograd.function import once_differentiable
import numpy as np
from typing import List
class SparseConvFunction(Function): class SparseConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out, algo): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
...@@ -34,6 +40,8 @@ class SparseConvFunction(Function): ...@@ -34,6 +40,8 @@ class SparseConvFunction(Function):
algo=algo) algo=algo)
@staticmethod @staticmethod
@once_differentiable
@amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
...@@ -50,6 +58,7 @@ class SparseConvFunction(Function): ...@@ -50,6 +58,7 @@ class SparseConvFunction(Function):
class SparseInverseConvFunction(Function): class SparseInverseConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out, algo): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
...@@ -64,6 +73,8 @@ class SparseInverseConvFunction(Function): ...@@ -64,6 +73,8 @@ class SparseInverseConvFunction(Function):
algo=algo) algo=algo)
@staticmethod @staticmethod
@once_differentiable
@amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward(features, input_bp, filters_bp = ops.indice_conv_backward(features,
...@@ -78,8 +89,67 @@ class SparseInverseConvFunction(Function): ...@@ -78,8 +89,67 @@ class SparseInverseConvFunction(Function):
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None
class SparseImplicitGemmFunction(Function):
    """Autograd function wrapping ``ops.implicit_gemm`` and its backward op.

    ``forward`` runs the implicit-gemm convolution and stashes everything the
    backward op needs on ``ctx``; ``backward`` forwards that state to
    ``ops.implicit_gemm_backward`` and returns gradients for ``features`` and
    ``filters`` only.
    """
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, features: torch.Tensor, filters: torch.Tensor,
                pair_fwd: torch.Tensor, pair_bwd: torch.Tensor,
                pair_mask_fwd_splits: List[torch.Tensor],
                pair_mask_bwd_splits: List[torch.Tensor],
                mask_argsort_fwd_splits: List[torch.Tensor],
                mask_argsort_bwd_splits: List[torch.Tensor],
                num_activate_out: int, masks: List[np.ndarray], is_train: bool,
                is_subm: bool):
        """Run the forward implicit-gemm convolution.

        Returns:
            The first element of the ``ops.implicit_gemm`` result (the output
            features). The other two elements (``mask_out`` and
            ``mask_width``) are kept on ``ctx`` for the backward pass.
        """
        out, mask_out, mask_width = ops.implicit_gemm(
            features, filters, pair_fwd, pair_mask_fwd_splits,
            mask_argsort_fwd_splits, num_activate_out, masks, is_train,
            is_subm)
        # Only these four go through save_for_backward; the remaining state is
        # attached to ctx as plain attributes.
        ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
        ctx.mask_width = mask_width
        ctx.mask_out = mask_out
        ctx.pair_mask_fwd_splits = pair_mask_fwd_splits
        ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits
        ctx.pair_mask_bwd_splits = pair_mask_bwd_splits
        ctx.mask_argsort_bwd_splits = mask_argsort_bwd_splits
        ctx.masks = masks
        ctx.is_subm = is_subm
        return out

    @staticmethod
    @once_differentiable
    @amp.custom_bwd
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. ``features`` and ``filters``.

        Returns:
            A 12-tuple, one entry per ``forward`` input (excluding ``ctx``):
            ``(input_bp, filters_bp, None, ..., None)``.
        """
        features, filters, pair_fwd, pair_bwd = ctx.saved_tensors
        input_bp, filters_bp = ops.implicit_gemm_backward(
            features,
            filters,
            grad_output,
            pair_fwd,
            pair_bwd,
            ctx.pair_mask_fwd_splits,
            ctx.pair_mask_bwd_splits,
            ctx.mask_argsort_fwd_splits,
            ctx.mask_argsort_bwd_splits,
            mask_output_fwd=ctx.mask_out,
            masks=ctx.masks,
            mask_width=ctx.mask_width,
            is_subm=ctx.is_subm)
        # forward takes 12 inputs; the 10 after features/filters get no
        # gradient.
        no_grads = [None] * 10
        # Parenthesized on purpose: a bare ``return a, b, *rest`` is a
        # SyntaxError before Python 3.8.
        return (input_bp, filters_bp, *no_grads)
class SubMConvFunction(Function): class SubMConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out, algo): num_activate_out, algo):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
...@@ -94,6 +164,8 @@ class SubMConvFunction(Function): ...@@ -94,6 +164,8 @@ class SubMConvFunction(Function):
algo=algo) algo=algo)
@staticmethod @staticmethod
@once_differentiable
@amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward(features, input_bp, filters_bp = ops.indice_conv_backward(features,
...@@ -110,6 +182,7 @@ class SubMConvFunction(Function): ...@@ -110,6 +182,7 @@ class SubMConvFunction(Function):
class SparseMaxPoolFunction(Function): class SparseMaxPoolFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, indice_pairs, indice_pair_num, def forward(ctx, features, indice_pairs, indice_pair_num,
num_activate_out): num_activate_out):
out = ops.indice_maxpool(features, indice_pairs, indice_pair_num, out = ops.indice_maxpool(features, indice_pairs, indice_pair_num,
...@@ -118,14 +191,35 @@ class SparseMaxPoolFunction(Function): ...@@ -118,14 +191,35 @@ class SparseMaxPoolFunction(Function):
return out return out
@staticmethod @staticmethod
@once_differentiable
@amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, out = ctx.saved_tensors indice_pairs, indice_pair_num, features, out = ctx.saved_tensors
input_bp = ops.indice_maxpool_backward(features, out, grad_output, input_bp = ops.indice_maxpool_backward(features, out, grad_output,
indice_pairs, indice_pair_num) indice_pairs, indice_pair_num)
return input_bp, None, None, None return input_bp, None, None, None
class SparseMaxPoolImplicitGemmFunction(Function):
    """Autograd function pairing ``ops.indice_maxpool_implicit_gemm`` with its backward op."""
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
                indice_pairs_bwd: torch.Tensor, num_activate_out: int):
        """Pool ``features`` using the forward pairs and return the result.

        The backward pairs, the input features and the pooled output are saved
        on ``ctx`` for use in ``backward``.
        """
        pooled = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd,
                                                  num_activate_out)
        ctx.save_for_backward(indice_pairs_bwd, features, pooled)
        return pooled

    @staticmethod
    @once_differentiable
    @amp.custom_bwd
    def backward(ctx, grad_output):
        """Return the gradient for ``features``; the other three inputs get ``None``."""
        saved_pairs_bwd, saved_features, saved_out = ctx.saved_tensors
        grad_features = ops.indice_maxpool_implicit_gemm_backward(
            saved_features, saved_out, grad_output, saved_pairs_bwd)
        return grad_features, None, None, None
indice_conv = SparseConvFunction.apply indice_conv = SparseConvFunction.apply
implicit_gemm = SparseImplicitGemmFunction.apply
indice_inverse_conv = SparseInverseConvFunction.apply indice_inverse_conv = SparseInverseConvFunction.apply
indice_subm_conv = SubMConvFunction.apply indice_subm_conv = SubMConvFunction.apply
indice_maxpool = SparseMaxPoolFunction.apply indice_maxpool = SparseMaxPoolFunction.apply
indice_maxpool_implicit_gemm = SparseMaxPoolImplicitGemmFunction.apply
...@@ -14,20 +14,47 @@ ...@@ -14,20 +14,47 @@
import functools import functools
from enum import Enum from enum import Enum
from cumm import tensorview as tv from cumm import tensorview as tv
from cumm.conv.bases import KRSC, NHWC, ConvOpType
from cumm.gemm.algospec.core import ShuffleStrideType from cumm.gemm.algospec.core import ShuffleStrideType
import torch import torch
import numpy as np import numpy as np
import spconv import spconv
from spconv.algo import AlgoHint, ConvAlgo from spconv.core import AlgoHint, ConvAlgo
from typing import List, Union from typing import List, Optional, Union
from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.algo import GEMM # , GATHER, SCATTER import spconv.core_cc as _ext
if hasattr(_ext, "cumm"):
from spconv.algo import GEMM, CONV # , GATHER, SCATTER
else:
GEMM = None
CONV = None
import time import time
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO
import pickle from cumm.gemm import codeops
from pathlib import Path
DEBUG = False
class ThrustSortAllocator:
    """Scratch-buffer allocator backed by torch tensors.

    ``alloc`` hands out the raw data pointer of a ``uint8`` tensor living on
    ``device``. Buffers are cached in ``alloced_objs`` (keyed by the size they
    were created with) and kept for the lifetime of the allocator, so a
    request that fits an existing buffer allocates nothing.
    """
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # requested size -> uint8 tensor of exactly that many elements
        self.alloced_objs = {}
        self.device = device

    def alloc(self, n: int):
        """Return the data pointer of a buffer holding at least ``n`` bytes.

        Lookup order: a buffer created with exactly size ``n``; otherwise the
        first cached buffer (in insertion order) strictly larger than ``n``;
        otherwise a freshly allocated, cached buffer of size ``n``.

        Successive calls may return the same pointer: buffers are never
        marked as in use.
        """
        buffer = self.alloced_objs.get(n)
        if buffer is None:
            buffer = next((cached
                           for size, cached in self.alloced_objs.items()
                           if size > n), None)
        if buffer is None:
            buffer = torch.empty([n], dtype=torch.uint8, device=self.device)
            self.alloced_objs[n] = buffer
        return buffer.data_ptr()
def get_conv_output_size(input_size, kernel_size, stride, padding, dilation): def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
ndim = len(input_size) ndim = len(input_size)
...@@ -68,6 +95,11 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -68,6 +95,11 @@ def get_indice_pairs(indices: torch.Tensor,
transpose: bool = False): transpose: bool = False):
# torch.cuda.synchronize() # torch.cuda.synchronize()
# t = time.time() # t = time.time()
# stream = get_current_stream()
# CONV.stream_synchronize(stream)
# t = time.time()
ndim = indices.shape[1] - 1 ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1) kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
if not subm: if not subm:
...@@ -80,9 +112,11 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -80,9 +112,11 @@ def get_indice_pairs(indices: torch.Tensor,
else: else:
out_shape = spatial_shape out_shape = spatial_shape
if any([x == 0 for x in out_shape]): if any([x == 0 for x in out_shape]):
raise ValueError(f"your out spatial shape {out_shape} reach zero!!! input shape: {spatial_shape}") raise ValueError(
f"your out spatial shape {out_shape} reach zero!!! input shape: {spatial_shape}"
)
assert algo == ConvAlgo.Native, "TODO" assert algo == ConvAlgo.Native, "TODO"
stream = get_current_stream() # indices = indices.cpu()
pair = torch.full((2, kv, indices.shape[0]), pair = torch.full((2, kv, indices.shape[0]),
-1, -1,
...@@ -95,15 +129,213 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -95,15 +129,213 @@ def get_indice_pairs(indices: torch.Tensor,
inds_tv = torch_tensor_to_tv(indices) inds_tv = torch_tensor_to_tv(indices)
pair_tv = torch_tensor_to_tv(pair) pair_tv = torch_tensor_to_tv(pair)
indice_num_per_loc_tv = torch_tensor_to_tv(indice_num_per_loc) indice_num_per_loc_tv = torch_tensor_to_tv(indice_num_per_loc)
if subm: if subm:
out_inds = indices out_inds = indices
if indices.is_cuda:
stream = get_current_stream()
hashdata = torch.empty((out_inds.shape[0] * 2, ),
dtype=torch.int64,
device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds)
hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
SpconvOps.generate_subm_conv_inds(inds_tv,
hashdata_tv,
pair_tv,
out_inds_tv,
indice_num_per_loc_tv,
batch_size=batch_size,
input_dims=spatial_shape,
ksize=ksize,
dilation=dilation,
stream_int=stream)
else:
out_inds_tv = torch_tensor_to_tv(out_inds)
SpconvOps.generate_subm_conv_inds_cpu(inds_tv,
pair_tv,
out_inds_tv,
indice_num_per_loc_tv,
batch_size=batch_size,
input_dims=spatial_shape,
ksize=ksize,
dilation=dilation)
# CONV.stream_synchronize(stream)
# print("SUBM", time.time() - t)
else:
if indices.is_cuda:
stream = get_current_stream()
indice_pairs_uniq = torch.empty((pair.numel() // 2 + 1, ),
dtype=indices.dtype,
device=indices.device)
indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq)
SpconvOps.generate_conv_inds_stage1(inds_tv,
pair_tv,
indice_pairs_uniq_tv,
indice_num_per_loc_tv,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream)
uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1
uniq_res_tv = torch_tensor_to_tv(uniq_res)
# num_act_out = SpconvOps.generate_conv_inds_stage1_5(
# indice_pairs_uniq_tv,
# ndim,
# uniq_size=indice_pairs_uniq_tv.size,
# stream_int=stream)
# uniq_res_tv = indice_pairs_uniq_tv.slice_first_axis(0, num_act_out)
out_inds = torch.empty((num_act_out, indices.shape[1]),
dtype=indices.dtype,
device=indices.device)
hashdata = torch.empty((out_inds.shape[0] * 2, ),
dtype=torch.int64,
device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds)
hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
SpconvOps.generate_conv_inds_stage2(inds_tv,
hashdata_tv,
pair_tv,
uniq_res_tv,
out_inds_tv,
num_out_act=num_act_out,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream)
else:
out_inds = torch.empty((kv * indices.shape[0], indices.shape[1]),
dtype=indices.dtype,
device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds)
num_act_out = SpconvOps.generate_conv_inds_cpu(
inds_tv,
pair_tv,
out_inds_tv,
indice_num_per_loc_tv,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
transposed=transpose)
out_inds = out_inds[:num_act_out]
# CONV.stream_synchronize(stream)
# print("REGU", time.time() - t)
return out_inds, pair, indice_num_per_loc
def get_indice_pairs_implicit_gemm(indices: torch.Tensor,
batch_size: int,
spatial_shape: List[int],
algo: ConvAlgo,
ksize: List[int],
stride: List[int],
padding: List[int],
dilation: List[int],
out_padding: List[int],
subm: bool = False,
transpose: bool = False,
is_train: bool = True,
alloc: Optional[ThrustSortAllocator] = None):
"""
Why return tuple? because pytorch seems don't support custom object in autograd.
return: (
out_inds,
num_inds_per_loc,
pair_fwd,
pair_bwd, # None if subm or inference mode
pair_mask_fwd_splits,
pair_mask_bwd_splits, # None if subm or inference mode
mask_argsort_fwd_splits,
mask_argsort_bwd_splits, # None if subm or inference mode
masks,
)
"""
stream = get_current_stream()
t = 0
if DEBUG:
CONV.stream_synchronize(stream)
t = time.time()
assert indices.is_cuda, "implicit gemm only support cuda"
ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
# TODO in future we will support up to 128 kernel volume.
assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm"
if not subm:
if transpose:
out_shape = get_deconv_output_size(spatial_shape, ksize, stride,
padding, dilation, out_padding)
else:
out_shape = get_conv_output_size(spatial_shape, ksize, stride,
padding, dilation)
else:
out_shape = spatial_shape
if any([x == 0 for x in out_shape]):
raise ValueError(
f"your out spatial shape {out_shape} reach zero!!! input shape: {spatial_shape}"
)
assert algo == ConvAlgo.MaskImplicitGemm or algo == ConvAlgo.MaskSplitImplicitGemm, "TODO"
is_mask_split = algo == ConvAlgo.MaskSplitImplicitGemm
mask_split_count = 2 if is_mask_split else 1
if subm:
pair = torch.full((2, kv, indices.shape[0]),
-1,
dtype=indices.dtype,
device=indices.device)
else:
# for regular conv, pair-in not equal to pair-out
pair = torch.full((kv, indices.shape[0]),
-1,
dtype=indices.dtype,
device=indices.device)
indice_num_per_loc = torch.zeros((kv, ),
dtype=indices.dtype,
device=indices.device)
inds_tv = torch_tensor_to_tv(indices)
pair_tv = torch_tensor_to_tv(pair)
indice_num_per_loc_tv = torch_tensor_to_tv(indice_num_per_loc)
if is_mask_split:
kv_div_2 = kv // 2
remain = kv - kv_div_2
mask_np_1 = np.array([1], dtype=np.uint64)
first = ((mask_np_1 << (remain)) - 1)
second = ((mask_np_1 << (kv_div_2)) - 1) << remain
masks = [first.astype(np.uint32), second.astype(np.uint32)]
else:
masks = [np.array([0xffffffff], dtype=np.uint32)]
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
if subm:
out_inds = indices
hashdata = torch.empty((out_inds.shape[0] * 2, ), hashdata = torch.empty((out_inds.shape[0] * 2, ),
dtype=torch.int64, dtype=torch.int64,
device=indices.device) device=indices.device)
pair_mask = torch.empty((mask_split_count, indices.shape[0]),
dtype=torch.int32,
device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds) out_inds_tv = torch_tensor_to_tv(out_inds)
hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64) hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
pair_mask_tv = torch_tensor_to_tv(pair_mask, dtype=tv.uint32)
SpconvOps.generate_subm_conv_inds(inds_tv, SpconvOps.generate_subm_conv_inds(inds_tv,
hashdata_tv, hashdata_tv,
...@@ -114,64 +346,238 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -114,64 +346,238 @@ def get_indice_pairs(indices: torch.Tensor,
input_dims=spatial_shape, input_dims=spatial_shape,
ksize=ksize, ksize=ksize,
dilation=dilation, dilation=dilation,
indice_pair_mask=pair_mask_tv,
stream_int=stream) stream_int=stream)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print("SUBM", time.time() - t) # print("SUBM0", time.time() - t)
# CONV.stream_synchronize(stream)
mask_argsort = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32,
device=indices.device)
mask_argsort_tv = torch_tensor_to_tv(mask_argsort)
if alloc is None:
alloc = ThrustSortAllocator(indices.device)
for j in range(mask_split_count):
# thrust don't provide two-step sort (first step return workspace size)
# so I use this stupid hack to use torch allocator without touch
# pytorch binary (c++).
# f**k thrust
SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j], alloc.alloc,
mask_argsort_tv[j], stream)
# CONV.stream_synchronize(stream)
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
mask_argsort_in_splits = [
mask_argsort[i] for i in range(mask_split_count)
]
if DEBUG:
CONV.stream_synchronize(stream)
print("SUBM", time.time() - t)
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else: else:
indice_pairs_uniq = torch.empty((pair.numel() // 2 + 1, ), if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_PREPARE", time.time() - t)
t = time.time()
pair_bwd = pair
pair_bwd_tv = pair_tv
indice_pairs_uniq = torch.empty((pair.numel() + 1, ),
dtype=indices.dtype, dtype=indices.dtype,
device=indices.device) device=indices.device)
indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq) indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq)
SpconvOps.generate_conv_inds_stage1(inds_tv, SpconvOps.generate_conv_inds_mask_stage1(inds_tv,
pair_tv, pair_bwd_tv,
indice_pairs_uniq_tv, indice_pairs_uniq_tv,
indice_num_per_loc_tv, indice_num_per_loc_tv,
batch_size=batch_size, batch_size=batch_size,
output_dims=out_shape, output_dims=out_shape,
input_dims=spatial_shape, input_dims=spatial_shape,
ksize=ksize, ksize=ksize,
stride=stride, stride=stride,
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
transposed=transpose, transposed=transpose,
stream_int=stream) stream_int=stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S1", time.time() - t)
t = time.time()
uniq_res = indice_pairs_uniq.unique() uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1 num_act_out = uniq_res.shape[0] - 1
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_UNIQ", time.time() - t)
t = time.time()
uniq_res_tv = torch_tensor_to_tv(uniq_res) uniq_res_tv = torch_tensor_to_tv(uniq_res)
# num_act_out = SpconvOps.generate_conv_inds_stage1_5(
# indice_pairs_uniq_tv,
# ndim,
# uniq_size=indice_pairs_uniq_tv.size,
# stream_int=stream)
# uniq_res_tv = indice_pairs_uniq_tv.slice_first_axis(0, num_act_out)
out_inds = torch.empty((num_act_out, indices.shape[1]), out_inds = torch.empty((num_act_out, indices.shape[1]),
dtype=indices.dtype, dtype=indices.dtype,
device=indices.device) device=indices.device)
pair_fwd = torch.full((kv, num_act_out),
-1,
dtype=indices.dtype,
device=indices.device)
pair_mask_fwd = torch.zeros((mask_split_count, num_act_out),
dtype=torch.int32,
device=indices.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_tv = torch_tensor_to_tv(pair_mask_fwd, dtype=tv.uint32)
pair_mask_bwd = torch.Tensor()
pair_mask_bwd_tv = tv.Tensor()
if is_train:
pair_mask_bwd = torch.zeros((mask_split_count, indices.shape[0]),
dtype=torch.int32,
device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
dtype=tv.uint32)
hashdata = torch.empty((out_inds.shape[0] * 2, ), hashdata = torch.empty((out_inds.shape[0] * 2, ),
dtype=torch.int64, dtype=torch.int64,
device=indices.device) device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds) out_inds_tv = torch_tensor_to_tv(out_inds)
hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64) hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
SpconvOps.generate_conv_inds_stage2(inds_tv, if DEBUG:
hashdata_tv,
pair_tv, CONV.stream_synchronize(stream)
uniq_res_tv, print("REGU_S2_PREPARE", time.time() - t)
out_inds_tv, t = time.time()
num_out_act=num_act_out,
batch_size=batch_size, SpconvOps.generate_conv_inds_mask_stage2(inds_tv,
output_dims=out_shape, hashdata_tv,
input_dims=spatial_shape, pair_fwd_tv,
ksize=ksize, pair_bwd_tv,
stride=stride, uniq_res_tv,
padding=padding, out_inds_tv,
dilation=dilation, pair_mask_fwd_tv,
transposed=transpose, pair_mask_bwd_tv,
stream_int=stream) num_out_act=num_act_out,
# torch.cuda.synchronize() batch_size=batch_size,
# print("REGU", time.time() - t) output_dims=out_shape,
return out_inds, pair, indice_num_per_loc input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S2", time.time() - t)
t = time.time()
mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32,
device=indices.device)
mask_argsort_fwd_tv = torch_tensor_to_tv(mask_argsort_fwd)
mask_argsort_bwd_tv = tv.Tensor()
mask_argsort_bwd = torch.Tensor()
if is_train:
mask_argsort_bwd = torch.empty(
(mask_split_count, indices.shape[0]),
dtype=torch.int32,
device=indices.device)
mask_argsort_bwd_tv = torch_tensor_to_tv(mask_argsort_bwd)
if alloc is None:
alloc = ThrustSortAllocator(indices.device)
if is_mask_split:
for j in range(mask_split_count):
mask_tv = tv.from_numpy(masks[j])
# here we try to ensure only call allocator once.
if not is_train:
SpconvOps.sort_1d_by_key_split_allocator(
pair_mask_fwd_tv[j], alloc.alloc, mask_tv,
mask_argsort_fwd_tv[j], stream)
else:
if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
SpconvOps.sort_1d_by_key_split_allocator(
pair_mask_bwd_tv[j], alloc.alloc, mask_tv,
mask_argsort_bwd_tv[j], stream)
SpconvOps.sort_1d_by_key_split_allocator(
pair_mask_fwd_tv[j], alloc.alloc, mask_tv,
mask_argsort_fwd_tv[j], stream)
else:
SpconvOps.sort_1d_by_key_split_allocator(
pair_mask_fwd_tv[j], alloc.alloc, mask_tv,
mask_argsort_fwd_tv[j], stream)
SpconvOps.sort_1d_by_key_split_allocator(
pair_mask_bwd_tv[j], alloc.alloc, mask_tv,
mask_argsort_bwd_tv[j], stream)
# SpconvOps.sort_1d_by_key_split(pair_mask_fwd_tv[j], mask_tv,
# mask_argsort_fwd_tv[j], stream)
# if is_train:
# SpconvOps.sort_1d_by_key_split(pair_mask_bwd_tv[j],
# mask_tv,
# mask_argsort_bwd_tv[j],
# stream)
else:
# if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
if not is_train:
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0],
alloc.alloc,
mask_argsort_fwd_tv[0], stream)
else:
if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
SpconvOps.sort_1d_by_key_allocator(pair_mask_bwd_tv[0],
alloc.alloc,
mask_argsort_bwd_tv[0],
stream)
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0],
alloc.alloc,
mask_argsort_fwd_tv[0], stream)
else:
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0],
alloc.alloc,
mask_argsort_fwd_tv[0], stream)
SpconvOps.sort_1d_by_key_allocator(pair_mask_bwd_tv[0],
alloc.alloc,
mask_argsort_bwd_tv[0],
stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S2_FINISH", time.time() - t)
t = time.time()
# CONV.stream_synchronize(stream)
if not is_train:
pair_bwd = torch.Tensor()
pair_mask_bwd_splits: List[torch.Tensor] = []
mask_argsort_bwd_splits: List[torch.Tensor] = []
else:
pair_mask_bwd_splits = [
pair_mask_bwd[i] for i in range(mask_split_count)
]
mask_argsort_bwd_splits = [
mask_argsort_bwd[i] for i in range(mask_split_count)
]
pair_mask_fwd_splits = [
pair_mask_fwd[i] for i in range(mask_split_count)
]
mask_argsort_fwd_splits = [
mask_argsort_fwd[i] for i in range(mask_split_count)
]
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU", time.time() - t)
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
def indice_conv(features: torch.Tensor, def indice_conv(features: torch.Tensor,
...@@ -183,8 +589,12 @@ def indice_conv(features: torch.Tensor, ...@@ -183,8 +589,12 @@ def indice_conv(features: torch.Tensor,
subm: bool = False, subm: bool = False,
algo: ConvAlgo = ConvAlgo.Native): algo: ConvAlgo = ConvAlgo.Native):
# filters: RSKC # filters: RSKC
# torch.cuda.synchronize() # stream = get_current_stream()
# CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
if not features.is_contiguous():
features = features.contiguous()
if features.dtype == torch.int8 or features.dtype == torch.qint8: if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
if FILTER_HWIO: if FILTER_HWIO:
...@@ -195,6 +605,9 @@ def indice_conv(features: torch.Tensor, ...@@ -195,6 +605,9 @@ def indice_conv(features: torch.Tensor,
kv = filters.shape[0] kv = filters.shape[0]
kv_center = kv // 2 kv_center = kv // 2
if subm: if subm:
# out_features = torch.zeros((num_activate_out, out_channel),
# dtype=features.dtype,
# device=features.device)
if FILTER_HWIO: if FILTER_HWIO:
out_features = torch.mm(features, filters[kv_center]) out_features = torch.mm(features, filters[kv_center])
else: else:
...@@ -206,15 +619,47 @@ def indice_conv(features: torch.Tensor, ...@@ -206,15 +619,47 @@ def indice_conv(features: torch.Tensor,
if kv == 1 and subm: if kv == 1 and subm:
return out_features return out_features
stream = get_current_stream()
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
if subm and all(x == 0 for x in indice_pair_num_cpu): if subm and all(x == 0 for x in indice_pair_num_cpu):
return out_features return out_features
maxnhot = max(indice_pair_num_cpu)
arch = torch.cuda.get_device_capability()
inited: bool = subm inited: bool = subm
a = torch_tensor_to_tv(features) a = torch_tensor_to_tv(features)
c = torch_tensor_to_tv(out_features) c = torch_tensor_to_tv(out_features)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
pair_in = indice_pairs_tv[int(inverse)]
pair_out = indice_pairs_tv[int(not inverse)]
filters_tv = torch_tensor_to_tv(filters)
if not features.is_cuda:
# perform gather-mm-scatter_add for cpu data
assert not filters.is_cuda
assert not indice_pairs.is_cuda
inp_buffer = torch.empty([maxnhot, features.shape[1]],
dtype=features.dtype)
out_buffer = torch.empty([maxnhot, out_features.shape[1]],
dtype=out_features.dtype)
inp_buffer_tv = torch_tensor_to_tv(inp_buffer)
out_buffer_tv = torch_tensor_to_tv(out_buffer)
for i, nhot in enumerate(indice_pair_num_cpu):
if subm and i == kv_center:
continue
if subm and i > kv_center:
nhot = indice_pair_num_cpu[kv - i - 1]
if nhot <= 0:
continue
inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot)
SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices)
filters_cur = filters[i] if FILTER_HWIO else filters[i].T
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices)
return out_features
stream = get_current_stream()
profile_idx = kv_center profile_idx = kv_center
if subm: if subm:
profile_idx = kv_center - 1 profile_idx = kv_center - 1
...@@ -229,22 +674,24 @@ def indice_conv(features: torch.Tensor, ...@@ -229,22 +674,24 @@ def indice_conv(features: torch.Tensor,
profile_idx = i profile_idx = i
assert nhot_profile > 0, "this shouldn't happen" assert nhot_profile > 0, "this shouldn't happen"
# print(nhot_profile, indice_pair_num_cpu) # print(nhot_profile, indice_pair_num_cpu)
profile_res = GEMM.get_profiled_algo( arch = torch.cuda.get_device_capability()
a.shape,
filters.shape[-2:],
c.shape,
False,
False if FILTER_HWIO else True,
False,
arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds_shape=[nhot_profile],
c_inds_shape=[nhot_profile],
hint=AlgoHint.Fowrard.value)
tuned_res = GEMM.get_tuned_algo(a.dtype,
filters_tv.dtype,
c.dtype,
a.shape,
filters.shape[-2:],
c.shape,
False,
False if FILTER_HWIO else True,
False,
arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds_shape=[nhot_profile],
c_inds_shape=[nhot_profile],
hint=AlgoHint.Fowrard.value)
maxnhot = max(indice_pair_num_cpu) if tuned_res is None:
if profile_res is None:
# run profile on center # run profile on center
inp_indices_th = indice_pairs[int(inverse)][profile_idx, :nhot_profile] inp_indices_th = indice_pairs[int(inverse)][profile_idx, :nhot_profile]
out_indices_th = indice_pairs[int(not inverse)][ out_indices_th = indice_pairs[int(not inverse)][
...@@ -253,7 +700,7 @@ def indice_conv(features: torch.Tensor, ...@@ -253,7 +700,7 @@ def indice_conv(features: torch.Tensor,
out_indices = torch_tensor_to_tv(out_indices_th) out_indices = torch_tensor_to_tv(out_indices_th)
filter_tv = torch_tensor_to_tv(filters)[profile_idx] filter_tv = torch_tensor_to_tv(filters)[profile_idx]
profile_res, min_time = GEMM.profile_and_cache( tuned_res, min_time = GEMM.tune_and_cache(
a, a,
filter_tv, filter_tv,
c, c,
...@@ -268,11 +715,9 @@ def indice_conv(features: torch.Tensor, ...@@ -268,11 +715,9 @@ def indice_conv(features: torch.Tensor,
beta=0.0, beta=0.0,
hint=AlgoHint.Fowrard.value, hint=AlgoHint.Fowrard.value,
stream=stream) stream=stream)
# CONV.stream_synchronize(stream)
# t = time.time()
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
pair_in = indice_pairs_tv[int(inverse)]
pair_out = indice_pairs_tv[int(not inverse)]
filters_tv = torch_tensor_to_tv(filters)
for i, nhot in enumerate(indice_pair_num_cpu): for i, nhot in enumerate(indice_pair_num_cpu):
if subm and i == kv_center: if subm and i == kv_center:
continue continue
...@@ -285,28 +730,31 @@ def indice_conv(features: torch.Tensor, ...@@ -285,28 +730,31 @@ def indice_conv(features: torch.Tensor,
b = filters_tv[i] b = filters_tv[i]
# inp @ filter.T, NC @ KC # inp @ filter.T, NC @ KC
beta = 1.0 if inited else 0.0 beta = 1.0 if inited else 0.0
algo_desp = GEMM.run_profile(profile_res, algo_desp = GEMM.run_with_tuned_result(
a, tuned_res,
b, a,
c, b,
False, c,
False if FILTER_HWIO else True, False,
False, False if FILTER_HWIO else True,
arch=arch, False,
stream=stream, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, stream=stream,
a_inds=inp_indices, shuffle_type=ShuffleStrideType.ShuffleAC,
c_inds=out_indices, a_inds=inp_indices,
hint=AlgoHint.Fowrard.value, c_inds=out_indices,
alpha=1.0, hint=AlgoHint.Fowrard.value,
beta=beta) alpha=1.0,
beta=beta)
# gather_times += gather_time # gather_times += gather_time
inited = True inited = True
# torch.cuda.synchronize() # CONV.stream_synchronize(stream)
# # print(stream, valid_count, maxnhot, features.shape[0], features.shape[1], out_channel, time.time() - t, total_times, txt) # print(out_features.mean(), out_features.max(), out_features.min())
# # print(algo_desp, profile_res.external_gather, profile_res.splitk, features.shape[0], features.shape[1], out_channel, time.time() - t)
# # print(stream, valid_count, maxnhot, features.shape[0], features.shape[1], out_channel, time.time() - t, total_times, txt)
# # print(algo_desp, tuned_res.external_gather, tuned_res.splitk, features.shape[0], features.shape[1], out_channel, time.time() - t)
# print("F", time.time() - t)
return out_features return out_features
...@@ -323,6 +771,8 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -323,6 +771,8 @@ def indice_conv_backward(features: torch.Tensor,
inverse: bool = False, inverse: bool = False,
subm: bool = False, subm: bool = False,
algo: ConvAlgo = ConvAlgo.Native): algo: ConvAlgo = ConvAlgo.Native):
# print(out_bp.mean(), out_bp.max(), out_bp.min())
num_activate_out = out_bp.shape[0] num_activate_out = out_bp.shape[0]
out_channel = out_bp.shape[-1] out_channel = out_bp.shape[-1]
filters_shape = filters.shape filters_shape = filters.shape
...@@ -361,6 +811,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -361,6 +811,7 @@ def indice_conv_backward(features: torch.Tensor,
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
if subm and all(x == 0 for x in indice_pair_num_cpu): if subm and all(x == 0 for x in indice_pair_num_cpu):
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
maxnhot = max(indice_pair_num_cpu)
arch = torch.cuda.get_device_capability() arch = torch.cuda.get_device_capability()
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
...@@ -371,6 +822,37 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -371,6 +822,37 @@ def indice_conv_backward(features: torch.Tensor,
din_tv = torch_tensor_to_tv(din) din_tv = torch_tensor_to_tv(din)
if not features.is_cuda:
# perform gather-mm-scatter_add for cpu data
assert not filters.is_cuda
assert not indice_pairs.is_cuda
inp_buffer = torch.empty([maxnhot, features.shape[1]],
dtype=features.dtype)
out_buffer = torch.empty([maxnhot, out_bp.shape[1]],
dtype=out_bp.dtype)
inp_buffer_tv = torch_tensor_to_tv(inp_buffer)
out_buffer_tv = torch_tensor_to_tv(out_buffer)
for i, nhot in enumerate(indice_pair_num_cpu):
if subm and i == kv_center:
continue
if subm and i > kv_center:
nhot = indice_pair_num_cpu[kv - i - 1]
if nhot <= 0:
continue
inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot)
SpconvOps.gather_cpu(inp_buffer_tv, features_tv, inp_indices)
SpconvOps.gather_cpu(out_buffer_tv, out_bp_tv, out_indices)
filters_T_cur = filters[i].T if FILTER_HWIO else filters[i]
dfilters_cur = dfilters[i] if FILTER_HWIO else dfilters[i].T
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_cur)
torch.mm(out_buffer[:nhot], filters_T_cur, out=inp_buffer[:nhot])
SpconvOps.scatter_add_cpu(din_tv, inp_buffer_tv, inp_indices)
return (din, dfilters.reshape(filters_shape))
profile_idx = kv_center profile_idx = kv_center
if subm or indice_pair_num_cpu[profile_idx] == 0: if subm or indice_pair_num_cpu[profile_idx] == 0:
profile_idx = kv_center - 1 profile_idx = kv_center - 1
...@@ -386,7 +868,10 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -386,7 +868,10 @@ def indice_conv_backward(features: torch.Tensor,
assert nhot_profile > 0, "this shouldn't happen" assert nhot_profile > 0, "this shouldn't happen"
# print(nhot_profile, indice_pair_num_cpu) # print(nhot_profile, indice_pair_num_cpu)
profile_res_dgrad = GEMM.get_profiled_algo( tuned_res_dgrad = GEMM.get_tuned_algo(
out_bp_tv.dtype,
filters_tv.dtype,
din_tv.dtype,
out_bp_tv.shape, out_bp_tv.shape,
filters.shape[-2:], filters.shape[-2:],
din_tv.shape, din_tv.shape,
...@@ -398,11 +883,11 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -398,11 +883,11 @@ def indice_conv_backward(features: torch.Tensor,
a_inds_shape=[nhot_profile], a_inds_shape=[nhot_profile],
c_inds_shape=[nhot_profile], c_inds_shape=[nhot_profile],
hint=AlgoHint.BackwardInput.value) hint=AlgoHint.BackwardInput.value)
if profile_res_dgrad is None: if tuned_res_dgrad is None:
inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile) inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile)
out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile) out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile)
filter_tv = filters_tv[profile_idx] filter_tv = filters_tv[profile_idx]
profile_res_dgrad, min_time = GEMM.profile_and_cache( tuned_res_dgrad, min_time = GEMM.tune_and_cache(
out_bp_tv, out_bp_tv,
filter_tv, filter_tv,
din_tv, din_tv,
...@@ -423,7 +908,10 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -423,7 +908,10 @@ def indice_conv_backward(features: torch.Tensor,
else: else:
a_wgrad = features_tv a_wgrad = features_tv
b_wgrad = out_bp_tv b_wgrad = out_bp_tv
profile_res_wgrad = GEMM.get_profiled_algo( tuned_res_wgrad = GEMM.get_tuned_algo(
a_wgrad.dtype,
b_wgrad.dtype,
filters_tv.dtype,
a_wgrad.shape, a_wgrad.shape,
b_wgrad.shape, b_wgrad.shape,
filters.shape[-2:], filters.shape[-2:],
...@@ -436,7 +924,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -436,7 +924,7 @@ def indice_conv_backward(features: torch.Tensor,
b_inds_shape=[nhot_profile], b_inds_shape=[nhot_profile],
hint=AlgoHint.BackwardWeight.value) hint=AlgoHint.BackwardWeight.value)
if profile_res_wgrad is None: if tuned_res_wgrad is None:
inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile) inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile)
out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile) out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile)
dfilter_tv = dfilters_tv[profile_idx] dfilter_tv = dfilters_tv[profile_idx]
...@@ -446,7 +934,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -446,7 +934,7 @@ def indice_conv_backward(features: torch.Tensor,
else: else:
a_inds_wgrad = inp_indices a_inds_wgrad = inp_indices
b_inds_wgrad = out_indices b_inds_wgrad = out_indices
profile_res_wgrad, min_time = GEMM.profile_and_cache( tuned_res_wgrad, min_time = GEMM.tune_and_cache(
a_wgrad, a_wgrad,
b_wgrad, b_wgrad,
dfilter_tv, dfilter_tv,
...@@ -461,8 +949,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -461,8 +949,7 @@ def indice_conv_backward(features: torch.Tensor,
beta=0.0, beta=0.0,
hint=AlgoHint.BackwardWeight.value, hint=AlgoHint.BackwardWeight.value,
stream=stream) stream=stream)
# print(profile_res_wgrad.algo_desp, profile_res_wgrad.splitk, min_time) # print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time)
maxnhot = max(indice_pair_num_cpu)
# get workspace size for wgrad # get workspace size for wgrad
if not FILTER_HWIO: if not FILTER_HWIO:
a_shape = [maxnhot, out_bp_tv.dim(1)] a_shape = [maxnhot, out_bp_tv.dim(1)]
...@@ -472,16 +959,16 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -472,16 +959,16 @@ def indice_conv_backward(features: torch.Tensor,
a_shape = [maxnhot, features_tv.dim(1)] a_shape = [maxnhot, features_tv.dim(1)]
m, n, k = GEMM.extract_mnk(a_shape, m, n, k = GEMM.extract_mnk(a_shape,
b_shape, b_shape,
profile_res_wgrad.algo_desp.trans_a, tuned_res_wgrad.algo_desp.trans_a,
profile_res_wgrad.algo_desp.trans_b, tuned_res_wgrad.algo_desp.trans_b,
profile_res_wgrad.algo_desp.trans_c, tuned_res_wgrad.algo_desp.trans_c,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAB, shuffle_type=ShuffleStrideType.ShuffleAB,
a_inds_shape=[maxnhot], a_inds_shape=[maxnhot],
b_inds_shape=[maxnhot], b_inds_shape=[maxnhot],
hint=AlgoHint.BackwardWeight.value) hint=AlgoHint.BackwardWeight.value)
workspace_size = profile_res_wgrad.algo_desp.query_workspace_size( workspace_size = tuned_res_wgrad.algo_desp.query_workspace_size(
m, n, k, profile_res_wgrad.splitk) m, n, k, tuned_res_wgrad.splitk)
workspace = torch.Tensor() workspace = torch.Tensor()
workspace_tv = tv.Tensor() workspace_tv = tv.Tensor()
...@@ -490,7 +977,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -490,7 +977,7 @@ def indice_conv_backward(features: torch.Tensor,
dtype=torch.int8, dtype=torch.int8,
device=features.device) device=features.device)
workspace_tv = torch_tensor_to_tv(workspace) workspace_tv = torch_tensor_to_tv(workspace)
# print(workspace_size, m, n, k, profile_res_wgrad.splitk) # print(workspace_size, m, n, k, tuned_res_wgrad.splitk)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# di_time = time.time() - t # di_time = time.time() - t
# t = time.time() # t = time.time()
...@@ -507,21 +994,21 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -507,21 +994,21 @@ def indice_conv_backward(features: torch.Tensor,
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
# out.T @ inp, NK @ NC # out.T @ inp, NK @ NC
# print(features_tv.shape, out_bp_tv.shape) # print(features_tv.shape, out_bp_tv.shape)
GEMM.run_profile(profile_res_dgrad, GEMM.run_with_tuned_result(tuned_res_dgrad,
out_bp_tv, out_bp_tv,
filters_tv[i], filters_tv[i],
din_tv, din_tv,
False, False,
True if FILTER_HWIO else False, True if FILTER_HWIO else False,
False, False,
arch=arch, arch=arch,
stream=stream, stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds=out_indices, a_inds=out_indices,
c_inds=inp_indices, c_inds=inp_indices,
hint=AlgoHint.BackwardInput.value, hint=AlgoHint.BackwardInput.value,
alpha=1.0, alpha=1.0,
beta=beta) beta=beta)
if not FILTER_HWIO: if not FILTER_HWIO:
a = out_bp_tv a = out_bp_tv
...@@ -533,40 +1020,331 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -533,40 +1020,331 @@ def indice_conv_backward(features: torch.Tensor,
b = out_bp_tv b = out_bp_tv
a_inds = inp_indices a_inds = inp_indices
b_inds = out_indices b_inds = out_indices
GEMM.run_profile(profile_res_wgrad, GEMM.run_with_tuned_result(tuned_res_wgrad,
a, a,
b, b,
dfilters_tv[i], dfilters_tv[i],
True, True,
False, False,
False, False,
arch=arch, arch=arch,
stream=stream, stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAB, shuffle_type=ShuffleStrideType.ShuffleAB,
a_inds=a_inds, a_inds=a_inds,
b_inds=b_inds, b_inds=b_inds,
hint=AlgoHint.BackwardWeight.value, hint=AlgoHint.BackwardWeight.value,
alpha=1.0, alpha=1.0,
beta=beta, beta=beta,
workspace=workspace_tv) workspace=workspace_tv)
inited = True inited = True
# torch.cuda.synchronize() # torch.cuda.synchronize()
# dw_time = time.time() - t # dw_time = time.time() - t
# # print(dw_time + di_time, di_time, dw_time, profile_res_wgrad.splitk, profile_res_wgrad.algo_desp, dfilters.shape) # # print(dw_time + di_time, di_time, dw_time, tuned_res_wgrad.splitk, tuned_res_wgrad.algo_desp, dfilters.shape)
# # print(dw_time + di_time) # # print(dw_time + di_time)
# print("BWG", time.time() - t) # print("BWG", time.time() - t)
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
def implicit_gemm(features: torch.Tensor, filters: torch.Tensor,
                  pair_fwd: torch.Tensor,
                  pair_mask_fwd_splits: List[torch.Tensor],
                  mask_argsort_fwd_splits: List[torch.Tensor],
                  num_activate_out: int, masks: List[np.ndarray],
                  is_train: bool, is_subm: bool):
    """Run the forward sparse convolution through the implicit-GEMM kernels.

    A forward algorithm is fetched with ``CONV.get_tuned_algo``; when nothing
    is cached it is tuned on the first mask split via ``CONV.tune_and_cache``.
    The kernel is then launched once per mask split, the first launch with
    ``beta=0`` and every later one with ``beta=1``.

    Args:
        features: input features. Made contiguous when necessary; int8 and
            qint8 are rejected.
        filters: contiguous filter tensor. Its first axis is read as the
            output channel count and its last axis as the input channel count
            (the original comment calls this layout KRSC).
        pair_fwd: forward indice pairs, handed to the kernel as ``indices``.
        pair_mask_fwd_splits: one mask tensor per split; each is viewed as
            uint32 before being handed to the kernel.
        mask_argsort_fwd_splits: one argsort tensor per split.
        num_activate_out: number of rows of the output feature tensor.
        masks: one filter mask per split; every entry is converted to a
            Python scalar with ``.item()``.
        is_train: when true, a ``[num_split, div_up(num_activate_out,
            mask_width)]`` int32 buffer is allocated and passed to the kernel
            as ``mask_output``.
        is_subm: when true the output is allocated with ``torch.empty``,
            otherwise with ``torch.zeros``.

    Returns:
        ``(out_features, mask_output_fwd, mask_width)``. ``mask_output_fwd``
        is ``None`` when ``is_train`` is false; ``mask_width`` is
        ``tile_shape[0]`` of the selected algorithm.

    Raises:
        NotImplementedError: if ``features`` is int8 or qint8.
    """
    # Validate before touching any CUDA state so unsupported inputs fail fast.
    if not features.is_contiguous():
        features = features.contiguous()
    assert features.is_contiguous()
    assert filters.is_contiguous()

    if features.dtype == torch.int8 or features.dtype == torch.qint8:
        raise NotImplementedError("work in progress")
    # here filters is KRSC
    masks_ints = [m.item() for m in masks]
    out_channel = filters.shape[0]
    in_channel = filters.shape[-1]
    num_split = len(pair_mask_fwd_splits)
    # Collapse every axis between the first and the last into one.
    filters = filters.reshape(out_channel, -1, in_channel)
    if is_subm:
        out_features = torch.empty((num_activate_out, out_channel),
                                   dtype=features.dtype,
                                   device=features.device)
    else:
        out_features = torch.zeros((num_activate_out, out_channel),
                                   dtype=features.dtype,
                                   device=features.device)

    stream = get_current_stream()
    # NOTE(review): the tv tensors presumably alias the torch storage, so the
    # torch tensors above are kept bound for the whole call - confirm.
    pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
    features_tv = torch_tensor_to_tv(features)
    filters_tv = torch_tensor_to_tv(filters)
    out_features_tv = torch_tensor_to_tv(out_features)
    arch = torch.cuda.get_device_capability()
    pair_mask_fwd_split_tvs = [
        torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_fwd_splits
    ]
    mask_argsort_fwd_split_tvs = [
        torch_tensor_to_tv(x) for x in mask_argsort_fwd_splits
    ]

    tune_res = CONV.get_tuned_algo(ConvOpType.kForward, features_tv.dtype,
                                   filters_tv.dtype, out_features_tv.dtype,
                                   out_channel, in_channel, arch)
    if tune_res is None:
        # Nothing cached for this configuration: tune on the first split.
        tune_res, _ = CONV.tune_and_cache(
            ConvOpType.kForward,
            features_tv,
            filters_tv,
            out_features_tv,
            NHWC,
            KRSC,
            NHWC,
            arch,
            mask=pair_mask_fwd_split_tvs[0],
            mask_argsort=mask_argsort_fwd_split_tvs[0],
            indices=pair_fwd_tv,
            reverse_mask=False,
            mask_filter=masks_ints[0],
            stream=stream)
    mask_width = tune_res.algo_desp.tile_shape[0]

    if is_train:
        mask_output_fwd = torch.empty(
            [num_split,
             codeops.div_up(num_activate_out, mask_width)],
            dtype=torch.int32,
            device=features.device)
        # pytorch don't support uint32, so allocate int32 and view as uint32.
        mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd,
                                                dtype=tv.uint32)
        mask_output_fwd_tvs = [mask_output_fwd_tv[j] for j in range(num_split)]
    else:
        mask_output_fwd = None
        mask_output_fwd_tvs = [tv.Tensor() for _ in range(num_split)]

    for j in range(num_split):
        # First split overwrites the output, later splits accumulate.
        beta = 0 if j == 0 else 1
        CONV.run_with_tuned_result(tune_res,
                                   ConvOpType.kForward,
                                   features_tv,
                                   filters_tv,
                                   out_features_tv,
                                   mask=pair_mask_fwd_split_tvs[j],
                                   mask_argsort=mask_argsort_fwd_split_tvs[j],
                                   mask_output=mask_output_fwd_tvs[j],
                                   indices=pair_fwd_tv,
                                   reverse_mask=False,
                                   mask_filter=masks_ints[j],
                                   mask_width=-1,
                                   beta=beta,
                                   stream=stream,
                                   verbose=False)
    return out_features, mask_output_fwd, mask_width
def implicit_gemm_backward(features: torch.Tensor, filters: torch.Tensor,
                           out_bp: torch.Tensor, pair_fwd: torch.Tensor,
                           pair_bwd: torch.Tensor,
                           pair_mask_fwd_splits: List[torch.Tensor],
                           pair_mask_bwd_splits: List[torch.Tensor],
                           mask_argsort_fwd_splits: List[torch.Tensor],
                           mask_argsort_bwd_splits: List[torch.Tensor],
                           mask_output_fwd: torch.Tensor,
                           masks: List[np.ndarray], mask_width: int,
                           is_subm: bool):
    """Run the backward pass of the implicit-GEMM sparse convolution.

    Two kernels are launched per mask split: ``kBackwardInput`` writing into
    the input gradient and ``kBackwardWeight`` writing into the filter
    gradient. Both algorithms are fetched with ``CONV.get_tuned_algo`` and
    tuned on the first split when nothing is cached. The first split runs
    with ``beta=0``, later splits with ``beta=1``.

    Args:
        features: contiguous forward input features; int8 and qint8 are
            rejected.
        filters: contiguous filter tensor (the original comment calls the
            layout KRSC); the first axis is read as the output channel count
            and the last as the input channel count.
        out_bp: gradient w.r.t. the forward output; made contiguous when
            necessary.
        pair_fwd: forward indice pairs (used by the weight-gradient kernel).
        pair_bwd: backward indice pairs (used by the input-gradient kernel).
        pair_mask_fwd_splits: per-split forward masks.
        pair_mask_bwd_splits: per-split backward masks.
        mask_argsort_fwd_splits: per-split forward argsort tensors.
        mask_argsort_bwd_splits: per-split backward argsort tensors.
        mask_output_fwd: mask output tensor from the forward pass; row ``j``
            is used as the mask of the weight-gradient kernel for split ``j``.
        masks: one filter mask per split; every entry is converted to a
            Python scalar with ``.item()``.
        mask_width: mask width forwarded to the weight-gradient algorithm
            lookup, tuning and kernel launch.
        is_subm: when true the input-gradient kernel uses the forward
            masks/argsorts with ``reverse_mask=True`` and ``din`` is
            allocated with ``torch.empty_like``; otherwise the backward
            masks/argsorts are used and ``din`` starts as zeros.

    Returns:
        ``(din, dfilters)`` where ``dfilters`` is reshaped to the original
        shape of ``filters``.

    Raises:
        NotImplementedError: if ``features`` is int8 or qint8.
    """
    if features.dtype == torch.int8 or features.dtype == torch.qint8:
        raise NotImplementedError("work in progress")
    if not out_bp.is_contiguous():
        out_bp = out_bp.contiguous()
    assert out_bp.is_contiguous()
    assert filters.is_contiguous()
    assert features.is_contiguous()

    # here filters is KRSC
    # Convert the masks once, as implicit_gemm does, so a malformed mask is
    # reported before any kernel is launched.
    masks_ints = [m.item() for m in masks]
    filters_shape = filters.shape
    out_channel = filters.shape[0]
    in_channel = filters.shape[-1]
    num_split = len(pair_mask_fwd_splits)
    if is_subm:
        din = torch.empty_like(features)
    else:
        din = torch.zeros_like(features)
    dfilters = torch.zeros_like(filters)
    filters = filters.reshape(out_channel, -1, filters.shape[-1])
    kv = filters.shape[1]

    stream = get_current_stream()
    pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
    pair_bwd_tv = torch_tensor_to_tv(pair_bwd)
    features_tv = torch_tensor_to_tv(features)
    filters_tv = torch_tensor_to_tv(filters)
    dfilters_tv = torch_tensor_to_tv(dfilters)
    dout_tv = torch_tensor_to_tv(out_bp)
    din_tv = torch_tensor_to_tv(din)
    mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd, dtype=tv.uint32)
    arch = torch.cuda.get_device_capability()
    pair_mask_fwd_split_tvs = [
        torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_fwd_splits
    ]
    pair_mask_bwd_split_tvs = [
        torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_bwd_splits
    ]
    mask_argsort_fwd_split_tvs = [
        torch_tensor_to_tv(x) for x in mask_argsort_fwd_splits
    ]
    mask_argsort_bwd_split_tvs = [
        torch_tensor_to_tv(x) for x in mask_argsort_bwd_splits
    ]
    # The input-gradient kernel reads the forward masks for submanifold conv
    # and the backward masks otherwise; choose the lists once.
    if is_subm:
        dgrad_masks = pair_mask_fwd_split_tvs
        dgrad_argsorts = mask_argsort_fwd_split_tvs
    else:
        dgrad_masks = pair_mask_bwd_split_tvs
        dgrad_argsorts = mask_argsort_bwd_split_tvs

    dgrad_tune_res = CONV.get_tuned_algo(ConvOpType.kBackwardInput,
                                         din_tv.dtype, filters_tv.dtype,
                                         dout_tv.dtype, out_channel,
                                         in_channel, arch)
    wgrad_tune_res = CONV.get_tuned_algo(ConvOpType.kBackwardWeight,
                                         features_tv.dtype, dfilters_tv.dtype,
                                         dout_tv.dtype, out_channel,
                                         in_channel, arch, mask_width)

    if dgrad_tune_res is None:
        # TODO split mask maybe completely invalid
        dgrad_tune_res, _ = CONV.tune_and_cache(ConvOpType.kBackwardInput,
                                                din_tv,
                                                filters_tv,
                                                dout_tv,
                                                NHWC,
                                                KRSC,
                                                NHWC,
                                                arch,
                                                mask=dgrad_masks[0],
                                                mask_argsort=dgrad_argsorts[0],
                                                indices=pair_bwd_tv,
                                                reverse_mask=is_subm,
                                                mask_filter=masks_ints[0],
                                                stream=stream)
    if wgrad_tune_res is None:
        wgrad_tune_res, _ = CONV.tune_and_cache(
            ConvOpType.kBackwardWeight,
            features_tv,
            dfilters_tv,
            dout_tv,
            NHWC,
            KRSC,
            NHWC,
            arch,
            mask=pair_mask_fwd_split_tvs[0],
            mask_argsort=mask_argsort_fwd_split_tvs[0],
            indices=pair_fwd_tv,
            reverse_mask=False,
            mask_filter=masks_ints[0],
            mask_output=mask_output_fwd_tv[0],
            mask_width=mask_width,
            stream=stream)

    workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
                                               wgrad_tune_res.splitk,
                                               ConvOpType.kBackwardWeight,
                                               pair_fwd_tv.dim(1), in_channel,
                                               out_channel, kv)
    # `workspace` stays bound until return; NOTE(review): workspace_tv
    # presumably aliases its storage - confirm.
    workspace = torch.Tensor()
    workspace_tv = tv.Tensor()
    if workspace_size > 0:
        workspace = torch.empty((workspace_size, ),
                                dtype=torch.int8,
                                device=features.device)
        workspace_tv = torch_tensor_to_tv(workspace)

    for j in range(num_split):
        # First split overwrites the gradients, later splits accumulate.
        beta = 0 if j == 0 else 1
        CONV.run_with_tuned_result(dgrad_tune_res,
                                   ConvOpType.kBackwardInput,
                                   din_tv,
                                   filters_tv,
                                   dout_tv,
                                   mask=dgrad_masks[j],
                                   mask_argsort=dgrad_argsorts[j],
                                   mask_output=tv.Tensor(),
                                   indices=pair_bwd_tv,
                                   reverse_mask=is_subm,
                                   mask_filter=masks_ints[j],
                                   mask_width=-1,
                                   beta=beta,
                                   stream=stream)
        CONV.run_with_tuned_result(wgrad_tune_res,
                                   ConvOpType.kBackwardWeight,
                                   features_tv,
                                   dfilters_tv,
                                   dout_tv,
                                   mask=mask_output_fwd_tv[j],
                                   mask_argsort=mask_argsort_fwd_split_tvs[j],
                                   mask_output=tv.Tensor(),
                                   indices=pair_fwd_tv,
                                   reverse_mask=False,
                                   mask_filter=masks_ints[j],
                                   mask_width=mask_width,
                                   beta=beta,
                                   workspace=workspace_tv,
                                   stream=stream)
    return (din, dfilters.reshape(filters_shape))
def indice_maxpool(features: torch.Tensor, indice_pairs: torch.Tensor,
indice_pair_num: torch.Tensor, num_activate_out):
# torch.cuda.synchronize()
# t = time.time()
# stream = get_current_stream()
# CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
out_channel = features.shape[-1] out_channel = features.shape[-1]
out_features = torch.zeros((num_activate_out, out_channel), out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype, dtype=features.dtype,
device=features.device) device=features.device)
stream = get_current_stream() stream = 0
is_cpu = not features.is_cuda
if not is_cpu:
stream = get_current_stream()
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
out_features_tv = torch_tensor_to_tv(out_features) out_features_tv = torch_tensor_to_tv(out_features)
features_tv = torch_tensor_to_tv(features) features_tv = torch_tensor_to_tv(features)
...@@ -576,9 +1354,14 @@ def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out): ...@@ -576,9 +1354,14 @@ def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out):
continue continue
inp_indices = indice_pairs_tv[0][i].slice_first_axis(0, nhot) inp_indices = indice_pairs_tv[0][i].slice_first_axis(0, nhot)
out_indices = indice_pairs_tv[1][i].slice_first_axis(0, nhot) out_indices = indice_pairs_tv[1][i].slice_first_axis(0, nhot)
SpconvOps.maxpool_forward(out_features_tv, features_tv, out_indices, if is_cpu:
inp_indices, stream) SpconvOps.maxpool_forward_cpu(out_features_tv, features_tv,
# torch.cuda.synchronize() out_indices, inp_indices)
else:
SpconvOps.maxpool_forward(out_features_tv, features_tv,
out_indices, inp_indices, stream)
# CONV.stream_synchronize(stream)
# print("M", time.time() - t) # print("M", time.time() - t)
return out_features return out_features
...@@ -588,7 +1371,10 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, ...@@ -588,7 +1371,10 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs,
indice_pair_num): indice_pair_num):
out_channel = features.shape[-1] out_channel = features.shape[-1]
din = torch.zeros_like(features) din = torch.zeros_like(features)
stream = get_current_stream() is_cpu = not features.is_cuda
stream = 0
if not is_cpu:
stream = get_current_stream()
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
if not out_bp.is_contiguous(): if not out_bp.is_contiguous():
out_bp = out_bp.contiguous() out_bp = out_bp.contiguous()
...@@ -602,15 +1388,61 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, ...@@ -602,15 +1388,61 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs,
continue continue
inp_indices = indice_pairs_tv[0][i].slice_first_axis(0, nhot) inp_indices = indice_pairs_tv[0][i].slice_first_axis(0, nhot)
out_indices = indice_pairs_tv[1][i].slice_first_axis(0, nhot) out_indices = indice_pairs_tv[1][i].slice_first_axis(0, nhot)
SpconvOps.maxpool_backward(out_features_tv, features_tv, out_bp_tv, if is_cpu:
din_tv, out_indices, inp_indices, stream) SpconvOps.maxpool_backward_cpu(out_features_tv, features_tv,
out_bp_tv, din_tv, out_indices,
inp_indices)
else:
SpconvOps.maxpool_backward(out_features_tv, features_tv, out_bp_tv,
din_tv, out_indices, inp_indices,
stream)
return din return din
def indice_maxpool_implicit_gemm(features: torch.Tensor,
                                 indice_pairs: torch.Tensor,
                                 num_activate_out: int):
    """Forward max pooling driven by implicit-GEMM style indice pairs.

    Allocates an uninitialised ``(num_activate_out, features.shape[-1])``
    output on the device of ``features`` and hands it, together with the
    features and indice pairs, to
    ``SpconvOps.maxpool_implicit_gemm_forward`` on the current stream.

    Args:
        features: input features; must be a CUDA tensor.
        indice_pairs: indice pair tensor forwarded unchanged to the kernel.
        num_activate_out: number of rows of the output tensor.

    Returns:
        The output feature tensor passed to the kernel.

    Raises:
        AssertionError: if ``features`` is not a CUDA tensor.
    """
    # Guard first: the stream lookup and the allocation below only make
    # sense for CUDA input.
    assert features.is_cuda
    out_channel = features.shape[-1]
    out_features = torch.empty((num_activate_out, out_channel),
                               dtype=features.dtype,
                               device=features.device)
    # A single stream lookup is enough; the original fetched it twice.
    stream = get_current_stream()
    out_features_tv = torch_tensor_to_tv(out_features)
    features_tv = torch_tensor_to_tv(features)
    indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
    SpconvOps.maxpool_implicit_gemm_forward(out_features_tv, features_tv,
                                            indice_pairs_tv, stream)
    return out_features
def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
                                          indice_pairs):
    """Backward max pooling for the implicit-GEMM indice-pair layout.

    Allocates a zero tensor shaped like ``features`` and passes it, with the
    forward input/output and the output gradient, to
    ``SpconvOps.maxpool_implicit_gemm_backward`` on the current stream.

    Args:
        features: forward input features; must be a CUDA tensor.
        out_features: forward output features.
        out_bp: gradient w.r.t. ``out_features``; made contiguous when
            necessary.
        indice_pairs: indice pair tensor forwarded unchanged to the kernel.

    Returns:
        ``din``, the tensor handed to the kernel as the input gradient.

    Raises:
        AssertionError: if ``features`` is not a CUDA tensor.
    """
    # Guard before allocating anything: only CUDA input is supported here.
    assert features.is_cuda
    din = torch.zeros_like(features)
    if not out_bp.is_contiguous():
        out_bp = out_bp.contiguous()
    stream = get_current_stream()
    out_features_tv = torch_tensor_to_tv(out_features)
    features_tv = torch_tensor_to_tv(features)
    out_bp_tv = torch_tensor_to_tv(out_bp)
    din_tv = torch_tensor_to_tv(din)
    indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
    SpconvOps.maxpool_implicit_gemm_backward(out_features_tv, features_tv,
                                             out_bp_tv, din_tv,
                                             indice_pairs_tv, stream)
    return din
def pillar_scatter(features, coors, shape):
    """Placeholder for a pillar scatter operation; not implemented.

    Args:
        features: unused.
        coors: unused.
        shape: unused.

    Raises:
        NotImplementedError: always.
    """
    # Same exception type as before; the message only tells the caller which
    # entry point is missing.
    raise NotImplementedError("pillar_scatter is not implemented")
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -20,24 +20,26 @@ import torch ...@@ -20,24 +20,26 @@ import torch
from torch import nn from torch import nn
from torch.nn import init from torch.nn import init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from typing import List, Optional, Tuple, Union
from spconv import pytorch as spconv from spconv import pytorch as spconv
from spconv.algo import ConvAlgo from spconv.core import ConvAlgo
import spconv.pytorch.functional as Fsp import spconv.pytorch.functional as Fsp
from spconv.pytorch import ops from spconv.pytorch import ops
from spconv.pytorch.core import IndiceData from spconv.pytorch.core import IndiceData, ImplicitGemmIndiceData
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
class SparseMaxPool(SparseModule): class SparseMaxPool(SparseModule):
def __init__(self, def __init__(self,
ndim, ndim,
kernel_size, kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
stride=None, stride: Union[int, List[int], Tuple[int, ...]] = 1,
padding=0, padding: Union[int, List[int], Tuple[int, ...]] = 0,
dilation=1, dilation: Union[int, List[int], Tuple[int, ...]] = 1,
indice_key=None, indice_key: Optional[str] = None,
subm=False, subm: bool = False,
algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseMaxPool, self).__init__(name=name) super(SparseMaxPool, self).__init__(name=name)
if not isinstance(kernel_size, (list, tuple)): if not isinstance(kernel_size, (list, tuple)):
...@@ -57,6 +59,31 @@ class SparseMaxPool(SparseModule): ...@@ -57,6 +59,31 @@ class SparseMaxPool(SparseModule):
self.subm = subm self.subm = subm
self.dilation = dilation self.dilation = dilation
self.indice_key = indice_key self.indice_key = indice_key
kv = int(np.prod(kernel_size))
if algo is None:
# keep in mind that this algorithm is set for Inverse Sparse Conv;
# maxpool itself doesn't need a mask.
if kv <= 32:
if kv < 8:
algo = ConvAlgo.MaskImplicitGemm
else:
algo = ConvAlgo.MaskImplicitGemm
else:
algo = ConvAlgo.Native
if kv > 32:
assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now"
self.algo = algo
def extra_repr(self):
    """Return the argument summary shown inside this module's repr.

    ``kernel_size`` and ``stride`` are always listed. ``padding`` and
    ``dilation`` are listed only when they differ from all-zeros and
    all-ones respectively, and ``algo`` only when it is not ``None``.

    Returns:
        str: comma-separated ``name=value`` pairs.
    """
    parts = ['kernel_size={kernel_size}', 'stride={stride}']
    # The settings may be stored as lists; a list never compares equal to a
    # tuple, so normalise to tuple before testing against the default.
    if tuple(self.padding) != (0, ) * len(self.padding):
        parts.append('padding={padding}')
    if tuple(self.dilation) != (1, ) * len(self.dilation):
        parts.append('dilation={dilation}')
    s = ', '.join(parts).format(**self.__dict__)
    # Appended after format() so the algo text is never parsed as a
    # format template.
    if self.algo is not None:
        s += f', algo={self.algo}'
    return s
def forward(self, input): def forward(self, input):
assert isinstance(input, spconv.SparseConvTensor) assert isinstance(input, spconv.SparseConvTensor)
...@@ -96,37 +123,80 @@ class SparseMaxPool(SparseModule): ...@@ -96,37 +123,80 @@ class SparseMaxPool(SparseModule):
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
t = time.time() t = time.time()
out_padding = [0] * self.ndim
indice_dict = input.indice_dict.copy()
if self.algo == ConvAlgo.Native:
outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs(
indices, batch_size, spatial_shape, ConvAlgo.Native,
self.kernel_size, self.stride, self.padding, self.dilation, out_padding,
False)
if input.benchmark:
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[
self.name]["indice_gen_time"].append(interval)
t = time.time()
outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs( if self.indice_key is not None:
indices, datas = input.find_indice_pair(self.indice_key)
batch_size, if datas is None:
spatial_shape, indice_data = IndiceData(outids,
ConvAlgo.Native, indices,
self.kernel_size, indice_pairs,
self.stride, indice_pairs_num,
self.padding, spatial_shape,
self.dilation, is_subm=False,
0, algo=self.algo)
False) indice_dict[self.indice_key] = indice_data
if input.benchmark: else:
torch.cuda.synchronize() raise ValueError(f"indice key {self.indice_key} exists")
interval = time.time() - t
out_tensor.benchmark_record[self.name]["indice_gen_time"].append(
interval)
t = time.time()
if self.indice_key is not None: out_features = Fsp.indice_maxpool(features,
datas = input.find_indice_pair(self.indice_key) indice_pairs.to(device),
if datas is None: indice_pairs_num.to(device),
indice_data = IndiceData(outids, indices, indice_pairs, outids.shape[0])
indice_pairs_num, spatial_shape, is_subm=False) else:
input.indice_dict[self.indice_key] = indice_data res = ops.get_indice_pairs_implicit_gemm(indices,
else: batch_size,
raise ValueError("indice data exists") spatial_shape,
self.algo,
ksize=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
out_padding=out_padding,
subm=self.subm,
is_train=self.training,
alloc=input.thrust_allocator)
outids = res[0]
num_inds_per_loc = res[1]
pair_fwd = res[2]
pair_bwd = res[3]
pair_mask_fwd_splits = res[4]
pair_mask_bwd_splits = res[5]
mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7]
masks = res[8]
if self.indice_key is not None:
indice_data = ImplicitGemmIndiceData(
outids,
indices,
pair_fwd,
pair_bwd,
pair_mask_fwd_splits=pair_mask_fwd_splits,
pair_mask_bwd_splits=pair_mask_bwd_splits,
mask_argsort_fwd_splits=mask_argsort_fwd_splits,
mask_argsort_bwd_splits=mask_argsort_bwd_splits,
masks=masks,
is_subm=self.subm,
out_spatial_shape=out_spatial_shape,
algo=self.algo)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data
out_features = Fsp.indice_maxpool_implicit_gemm(
features, pair_fwd, pair_bwd, outids.shape[0])
out_features = Fsp.indice_maxpool(features, indice_pairs.to(device),
indice_pairs_num.to(device),
outids.shape[0])
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
interval = time.time() - t interval = time.time() - t
...@@ -137,6 +207,7 @@ class SparseMaxPool(SparseModule): ...@@ -137,6 +207,7 @@ class SparseMaxPool(SparseModule):
out_features.shape[0]) out_features.shape[0])
out_tensor = out_tensor.replace_feature(out_features) out_tensor = out_tensor.replace_feature(out_features)
out_tensor.indices = outids out_tensor.indices = outids
out_tensor.indice_dict = indice_dict
out_tensor.spatial_shape = out_spatial_shape out_tensor.spatial_shape = out_spatial_shape
return out_tensor return out_tensor
...@@ -148,6 +219,7 @@ class SparseMaxPool1d(SparseMaxPool): ...@@ -148,6 +219,7 @@ class SparseMaxPool1d(SparseMaxPool):
padding=0, padding=0,
dilation=1, dilation=1,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseMaxPool1d, self).__init__(1, super(SparseMaxPool1d, self).__init__(1,
kernel_size, kernel_size,
...@@ -155,8 +227,10 @@ class SparseMaxPool1d(SparseMaxPool): ...@@ -155,8 +227,10 @@ class SparseMaxPool1d(SparseMaxPool):
padding, padding,
dilation, dilation,
indice_key=indice_key, indice_key=indice_key,
algo=algo,
name=name) name=name)
class SparseMaxPool2d(SparseMaxPool): class SparseMaxPool2d(SparseMaxPool):
def __init__(self, def __init__(self,
kernel_size, kernel_size,
...@@ -164,6 +238,7 @@ class SparseMaxPool2d(SparseMaxPool): ...@@ -164,6 +238,7 @@ class SparseMaxPool2d(SparseMaxPool):
padding=0, padding=0,
dilation=1, dilation=1,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseMaxPool2d, self).__init__(2, super(SparseMaxPool2d, self).__init__(2,
kernel_size, kernel_size,
...@@ -171,6 +246,7 @@ class SparseMaxPool2d(SparseMaxPool): ...@@ -171,6 +246,7 @@ class SparseMaxPool2d(SparseMaxPool):
padding, padding,
dilation, dilation,
indice_key=indice_key, indice_key=indice_key,
algo=algo,
name=name) name=name)
...@@ -181,6 +257,7 @@ class SparseMaxPool3d(SparseMaxPool): ...@@ -181,6 +257,7 @@ class SparseMaxPool3d(SparseMaxPool):
padding=0, padding=0,
dilation=1, dilation=1,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseMaxPool3d, self).__init__(3, super(SparseMaxPool3d, self).__init__(3,
kernel_size, kernel_size,
...@@ -188,8 +265,10 @@ class SparseMaxPool3d(SparseMaxPool): ...@@ -188,8 +265,10 @@ class SparseMaxPool3d(SparseMaxPool):
padding, padding,
dilation, dilation,
indice_key=indice_key, indice_key=indice_key,
algo=algo,
name=name) name=name)
class SparseMaxPool4d(SparseMaxPool): class SparseMaxPool4d(SparseMaxPool):
def __init__(self, def __init__(self,
kernel_size, kernel_size,
...@@ -197,6 +276,7 @@ class SparseMaxPool4d(SparseMaxPool): ...@@ -197,6 +276,7 @@ class SparseMaxPool4d(SparseMaxPool):
padding=0, padding=0,
dilation=1, dilation=1,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None,
name=None): name=None):
super(SparseMaxPool4d, self).__init__(4, super(SparseMaxPool4d, self).__init__(4,
kernel_size, kernel_size,
...@@ -204,4 +284,5 @@ class SparseMaxPool4d(SparseMaxPool): ...@@ -204,4 +284,5 @@ class SparseMaxPool4d(SparseMaxPool):
padding, padding,
dilation, dilation,
indice_key=indice_key, indice_key=indice_key,
algo=algo,
name=name) name=name)
...@@ -18,15 +18,17 @@ from torch.autograd import Function ...@@ -18,15 +18,17 @@ from torch.autograd import Function
import spconv.pytorch as spconv import spconv.pytorch as spconv
#from torch.nn import Module #from torch.nn import Module
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
from spconv.pytorch.core import SparseConvTensor
from typing import List
class JoinTable(SparseModule): # Module): class JoinTable(SparseModule): # Module):
def forward(self, input): def forward(self, input: List[SparseConvTensor]):
output = spconv.SparseConvTensor( output = spconv.SparseConvTensor(
torch.cat([i.features for i in input], 1), input[1].indices, torch.cat([i.features for i in input], 1), input[0].indices,
input[1].spatial_shape, input[0].batch_size) input[0].spatial_shape, input[0].batch_size, input[0].grid, input[0].voxel_num,
output.indice_dict = input[1].indice_dict input[0].indice_dict)
output.grid = input[1].grid output.benchmark_record = input[1].benchmark_record
output.thrust_allocator = input[1].thrust_allocator
return output return output
def input_spatial_size(self, out_size): def input_spatial_size(self, out_size):
...@@ -34,14 +36,13 @@ class JoinTable(SparseModule): # Module): ...@@ -34,14 +36,13 @@ class JoinTable(SparseModule): # Module):
class AddTable(SparseModule): # Module): class AddTable(SparseModule): # Module):
def forward(self, input): def forward(self, input: List[SparseConvTensor]):
output = spconv.SparseConvTensor(sum([i.features for i in input]), output = spconv.SparseConvTensor(
input[1].indices, sum([i.features for i in input]), input[0].indices,
input[1].spatial_shape, input[0].spatial_shape, input[0].batch_size, input[0].grid, input[0].voxel_num,
input[1].batch_size) input[0].indice_dict)
output.indice_dict = input[1].indice_dict output.benchmark_record = input[1].benchmark_record
output.grid = input[1].grid output.thrust_allocator = input[1].thrust_allocator
return output return output
def input_spatial_size(self, out_size): def input_spatial_size(self, out_size):
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from cumm import tensorview as tv
from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream
class PointToVoxel(object):
    """Point-cloud voxelizer backed by preallocated output buffers.

    WARNING: you MUST construct PointToVoxel AFTER set device.

    All output tensors are allocated once in ``__init__`` on ``device`` and
    reused by every call; ``__call__`` returns slices of those buffers.
    """

    def __init__(self,
                 vsize_xyz: List[float],
                 coors_range_xyz: List[float],
                 num_point_features: int,
                 max_num_voxels: int,
                 max_num_points_per_voxel: int,
                 device: torch.device = torch.device("cpu:0")):
        """Allocate the voxel buffers and device-specific scratch tensors.

        Args:
            vsize_xyz: voxel size per axis; its length defines ``ndim``.
            coors_range_xyz: coordinate range, forwarded unchanged to
                ``SpconvOps.calc_point2voxel_meta_data``.
            num_point_features: size of the last axis of the voxel buffer.
            max_num_voxels: capacity of every output buffer (first axis).
            max_num_points_per_voxel: size of the middle axis of the voxel
                buffer.
            device: device for all buffers. Anything ``torch.device``
                accepts is allowed (a ``torch.device`` or a string such as
                ``"cuda:0"``).
        """
        # Normalise so that string devices work with the `.type` checks
        # used here and in __call__.
        device = torch.device(device)
        self.ndim = len(vsize_xyz)
        self.device = device
        vsize, grid_size, grid_stride, coors_range = SpconvOps.calc_point2voxel_meta_data(
            vsize_xyz, coors_range_xyz)
        self.num_point_features = num_point_features
        self.max_num_voxels = max_num_voxels
        self.max_num_points_per_voxel = max_num_points_per_voxel
        self.vsize = vsize
        self.grid_size = grid_size
        self.grid_stride = grid_stride
        self.coors_range = coors_range

        self.voxels = torch.zeros(
            [max_num_voxels, max_num_points_per_voxel, num_point_features],
            dtype=torch.float32,
            device=device)
        self.indices = torch.zeros([max_num_voxels, self.ndim],
                                   dtype=torch.int32,
                                   device=device)
        self.num_per_voxel = torch.zeros([max_num_voxels],
                                         dtype=torch.int32,
                                         device=device)
        if device.type == "cpu":
            # CPU path: an int32 tensor of shape `grid_size` filled with -1.
            self.hashdata = torch.full(grid_size,
                                       -1,
                                       dtype=torch.int32,
                                       device=device)
            self.point_indice_data = torch.Tensor()
        else:
            # Non-CPU path: minimal placeholders; __call__ grows them to
            # match the number of input points.
            self.hashdata = torch.empty([1, 2],
                                        dtype=torch.int64,
                                        device=device)
            self.point_indice_data = torch.empty([1],
                                                 dtype=torch.int64,
                                                 device=device)

    def __call__(self,
                 pc: torch.Tensor,
                 clear_voxels: bool = True,
                 empty_mean: bool = False):
        """Voxelize ``pc`` and return the filled prefix of each buffer.

        Args:
            pc: input points; must live on the same device type as this
                instance. ``pc.shape[0]`` is used as the point count.
            clear_voxels: forwarded unchanged to the point2voxel op.
            empty_mean: forwarded unchanged to the point2voxel op.

        Returns:
            Tuple ``(voxels, indices, num_per_voxel)``, each sliced to the
            first ``num_voxels`` rows, where ``num_voxels`` is
            ``res[0].shape[0]`` of the op result. The tensors are slices of
            this instance's buffers, not copies (presumably overwritten by
            the next call; copy them if they must outlive it).

        Raises:
            AssertionError: if ``pc`` is on a different device type.
        """
        assert pc.device.type == self.device.type, "your pc device is wrong"
        with torch.no_grad():
            # Views needed by both backends.
            pc_tv = torch_tensor_to_tv(pc)
            voxels_tv = torch_tensor_to_tv(self.voxels)
            indices_tv = torch_tensor_to_tv(self.indices)
            num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel)
            if self.device.type != "cpu":
                # Grow (never shrink) the scratch buffers: two hash rows
                # and one index slot per input point.
                expected_hash_data_num = pc.shape[0] * 2
                if self.hashdata.shape[0] < expected_hash_data_num:
                    self.hashdata = torch.empty([expected_hash_data_num, 2],
                                                dtype=torch.int64,
                                                device=self.device)
                if self.point_indice_data.shape[0] < pc.shape[0]:
                    self.point_indice_data = torch.empty([pc.shape[0]],
                                                         dtype=torch.int64,
                                                         device=self.device)
                stream = get_current_stream()
                # Each [2]-int64 row is handed over as one custom128 element.
                hashdata_tv = torch_tensor_to_tv(self.hashdata,
                                                 dtype=tv.custom128,
                                                 shape=[self.hashdata.shape[0]])
                point_indice_data_tv = torch_tensor_to_tv(self.point_indice_data)
                res = SpconvOps.point2voxel_cuda(pc_tv, voxels_tv, indices_tv,
                                                 num_per_voxel_tv, hashdata_tv,
                                                 point_indice_data_tv, self.vsize,
                                                 self.grid_size, self.grid_stride,
                                                 self.coors_range, empty_mean,
                                                 clear_voxels, stream)
            else:
                # No stream is requested here: point2voxel_cpu takes none.
                hashdata_tv = torch_tensor_to_tv(self.hashdata, dtype=tv.int32)
                res = SpconvOps.point2voxel_cpu(pc_tv, voxels_tv, indices_tv,
                                                num_per_voxel_tv, hashdata_tv,
                                                self.vsize, self.grid_size,
                                                self.grid_stride, self.coors_range,
                                                empty_mean, clear_voxels)
            num_voxels = res[0].shape[0]

        return (self.voxels[:num_voxels], self.indices[:num_voxels],
                self.num_per_voxel[:num_voxels])
...@@ -14,12 +14,17 @@ ...@@ -14,12 +14,17 @@
import numpy as np import numpy as np
from cumm import tensorview as tv from cumm import tensorview as tv
from spconv.core_cc.csrc.sparse.all.ops1d import Point2Voxel as Point2VoxelGPU1d
from spconv.core_cc.csrc.sparse.all.ops2d import Point2Voxel as Point2VoxelGPU2d
from spconv.core_cc.csrc.sparse.all.ops3d import Point2Voxel as Point2VoxelGPU3d
from spconv.core_cc.csrc.sparse.all.ops4d import Point2Voxel as Point2VoxelGPU4d
from spconv.core_cc.csrc.sparse.all.ops_cpu1d import Point2VoxelCPU as Point2VoxelCPU1d from spconv.core_cc.csrc.sparse.all.ops_cpu1d import Point2VoxelCPU as Point2VoxelCPU1d
from spconv.core_cc.csrc.sparse.all.ops_cpu2d import Point2VoxelCPU as Point2VoxelCPU2d from spconv.core_cc.csrc.sparse.all.ops_cpu2d import Point2VoxelCPU as Point2VoxelCPU2d
from spconv.core_cc.csrc.sparse.all.ops_cpu3d import Point2VoxelCPU as Point2VoxelCPU3d from spconv.core_cc.csrc.sparse.all.ops_cpu3d import Point2VoxelCPU as Point2VoxelCPU3d
from spconv.core_cc.csrc.sparse.all.ops_cpu4d import Point2VoxelCPU as Point2VoxelCPU4d from spconv.core_cc.csrc.sparse.all.ops_cpu4d import Point2VoxelCPU as Point2VoxelCPU4d
\ No newline at end of file import spconv.core_cc.csrc.sparse.all as __all
# NOTE(review): despite its name, this flag is True when the compiled module
# exposes an `ops1d` attribute, which is exactly the case in which the GPU
# Point2Voxel classes are imported below. The name looks inverted relative to
# what it guards -- confirm before using it elsewhere as "CPU only".
IS_CPU_ONLY_BUILD = hasattr(__all, "ops1d")
if IS_CPU_ONLY_BUILD:
    # GPU voxelizers, one class per spatial dimensionality (1d..4d).
    from spconv.core_cc.csrc.sparse.all.ops1d import Point2Voxel as Point2VoxelGPU1d
    from spconv.core_cc.csrc.sparse.all.ops2d import Point2Voxel as Point2VoxelGPU2d
    from spconv.core_cc.csrc.sparse.all.ops3d import Point2Voxel as Point2VoxelGPU3d
    from spconv.core_cc.csrc.sparse.all.ops4d import Point2Voxel as Point2VoxelGPU4d
...@@ -47,74 +47,55 @@ BWG 0.003300189971923828 ...@@ -47,74 +47,55 @@ BWG 0.003300189971923828
""" """
STR1 = """ STR1 = """
SUBM 0.00036716461181640625 SUBM 0.0005137920379638672
G 0.0010955333709716797 F 0.0012662410736083984
G 0.0010745525360107422 F 0.0016875267028808594
REGU 0.0006923675537109375 REGU 0.0009055137634277344
M 0.0005242824554443359 M 0.0009114742279052734
SUBM 0.0003108978271484375 SUBM 0.00037789344787597656
G 0.0010905265808105469 F 0.0020329952239990234
G 0.0011067390441894531 F 0.001947641372680664
REGU 0.00058746337890625 REGU 0.0009374618530273438
M 0.0005304813385009766 M 0.00045609474182128906
SUBM 0.0002682209014892578 SUBM 0.0009856224060058594
G 0.0010945796966552734 F 0.0009992122650146484
G 0.0011165142059326172 F 0.0010600090026855469
REGU 0.0005419254302978516 REGU 0.0006346702575683594
M 0.0005164146423339844 M 0.0004057884216308594
SUBM 0.00021505355834960938 SUBM 0.0006394386291503906
G 0.0010805130004882812 F 0.0008478164672851562
G 0.0010516643524169922 F 0.0008838176727294922
REGU 0.00052642822265625 REGU 0.0007183551788330078
M 0.0004677772521972656 M 0.00025177001953125
SUBM 0.0002262592315673828 SUBM 0.0009539127349853516
G 0.0010986328125 F 0.0009481906890869141
G 0.0010256767272949219 F 0.0010502338409423828
REGU 0.0005693435668945312 REGU 0.0007147789001464844
M 0.00048661231994628906 M 0.000274658203125
SUBM 0.0002319812774658203 SUBM 0.0007004737854003906
G 0.0011110305786132812 F 0.0009715557098388672
G 0.0011196136474609375 F 0.0012331008911132812
REGU 0.0005295276641845703 REGU 0.0008800029754638672
M 0.0005729198455810547 M 0.0002167224884033203
SUBM 0.00023889541625976562 SUBM 0.00045108795166015625
G 0.0005326271057128906 F 0.0006735324859619141
G 0.0005140304565429688 F 0.0008375644683837891
""" """
STR2 = """ STR2 = """
SUBM 0.0003352165222167969 F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A0T1688_NS00_C3_01LLL_1 0.0007038116455078125
G 0.001149892807006836 F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0007627010345458984
G 0.0017066001892089844 F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0007650852203369141
REGU 0.0006349086761474609 F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0008864402770996094
M 0.00048804283142089844 F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0004017353057861328
SUBM 0.00029850006103515625 F Turing_f16f16f16f16f16tnt_m32n128k64m32n32k32A1T1688_NS00_C3_01LLL_1 0.0006165504455566406
G 0.001767873764038086 F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0005872249603271484
G 0.0020656585693359375 F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0006289482116699219
REGU 0.0005462169647216797 F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0002968311309814453
M 0.0005753040313720703 F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0003299713134765625
SUBM 0.0002789497375488281 F Turing_f16f16f16f16f16tnt_m64n128k64m32n64k32A1T1688_NS00_C3_01LLL_1 0.0002288818359375
G 0.0012230873107910156 F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0002830028533935547
G 0.0014438629150390625 F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0001780986785888672
REGU 0.0005102157592773438 F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0003058910369873047
M 0.0005676746368408203
SUBM 0.00020241737365722656
G 0.00102996826171875
G 0.0011174678802490234
REGU 0.0005424022674560547
M 0.0005102157592773438
SUBM 0.0001976490020751953
G 0.0010385513305664062
G 0.0010204315185546875
REGU 0.0005321502685546875
M 0.00047278404235839844
SUBM 0.00021529197692871094
G 0.0010280609130859375
G 0.0010151863098144531
REGU 0.0004942417144775391
M 0.0004811286926269531
SUBM 0.00020694732666015625
G 0.0005142688751220703
G 0.0005171298980712891
""" """
def _handle_lines(s: str): def _handle_lines(s: str):
arr = s.split(" ") arr = s.split(" ")
......
...@@ -18,7 +18,8 @@ from pathlib import Path ...@@ -18,7 +18,8 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from cumm import tensorview as tv from cumm import tensorview as tv
from spconv.core import ConvAlgo
import spconv.pytorch as spconv import spconv.pytorch as spconv
from spconv.utils import Point2VoxelCPU3d from spconv.utils import Point2VoxelCPU3d
...@@ -41,6 +42,8 @@ def waymo_data(batch_size=1): ...@@ -41,6 +42,8 @@ def waymo_data(batch_size=1):
class Net(nn.Module): class Net(nn.Module):
def __init__(self, shape, algo): def __init__(self, shape, algo):
super().__init__() super().__init__()
pool_algo = algo
# pool_algo = ConvAlgo.Native
self.net = spconv.SparseSequential( self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 64, 3, bias=False, indice_key="c0", spconv.SubMConv3d(3, 64, 3, bias=False, indice_key="c0",
algo=algo), algo=algo),
...@@ -64,9 +67,9 @@ class Net(nn.Module): ...@@ -64,9 +67,9 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(32), # nn.BatchNorm1d(32),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"), # spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
# spconv.SparseMaxPool3d(2, 2), spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(64, spconv.SubMConv3d(64,
96, 96,
3, 3,
...@@ -81,9 +84,8 @@ class Net(nn.Module): ...@@ -81,9 +84,8 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(64), # nn.BatchNorm1d(64),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1"), # spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
# spconv.SparseMaxPool3d(2, 2),
spconv.SubMConv3d(96, spconv.SubMConv3d(96,
128, 128,
3, 3,
...@@ -98,9 +100,9 @@ class Net(nn.Module): ...@@ -98,9 +100,9 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(128), # nn.BatchNorm1d(128),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"), # spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"),
# spconv.SparseMaxPool3d(2, 2), spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(128, spconv.SubMConv3d(128,
160, 160,
3, 3,
...@@ -115,9 +117,9 @@ class Net(nn.Module): ...@@ -115,9 +117,9 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(128), # nn.BatchNorm1d(128),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"), # spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"),
# spconv.SparseMaxPool3d(2, 2), spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(160, spconv.SubMConv3d(160,
192, 192,
3, 3,
...@@ -132,8 +134,8 @@ class Net(nn.Module): ...@@ -132,8 +134,8 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(128), # nn.BatchNorm1d(128),
# nn.ReLU(), # nn.ReLU(),
# spconv.SparseMaxPool3d(2, 2, indice_key="m4"), spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo),
spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"), # spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"),
spconv.SubMConv3d(192, spconv.SubMConv3d(192,
224, 224,
...@@ -147,10 +149,10 @@ class Net(nn.Module): ...@@ -147,10 +149,10 @@ class Net(nn.Module):
bias=False, bias=False,
indice_key="c5", indice_key="c5",
algo=algo), algo=algo),
nn.BatchNorm1d(224), # nn.BatchNorm1d(224),
nn.ReLU(), # nn.ReLU(),
spconv.SparseConv3d(224, 224, 2, 2, bias=False, indice_key="m5"), # spconv.SparseConv3d(224, 224, 2, 2, bias=False, indice_key="m5"),
# spconv.SparseMaxPool3d(2, 2, indice_key="m5"), spconv.SparseMaxPool3d(2, 2, indice_key="m5", algo=pool_algo),
spconv.SubMConv3d(224, spconv.SubMConv3d(224,
256, 256,
3, 3,
...@@ -164,14 +166,14 @@ class Net(nn.Module): ...@@ -164,14 +166,14 @@ class Net(nn.Module):
indice_key="c6", indice_key="c6",
algo=algo), algo=algo),
nn.BatchNorm1d(256), # nn.BatchNorm1d(256),
nn.ReLU(), # nn.ReLU(),
spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False), # spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo),
nn.BatchNorm1d(128), # # nn.BatchNorm1d(128),
nn.ReLU(), # # nn.ReLU(),
spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False), # spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo),
) )
max_batch_size = 1 max_batch_size = 1
...@@ -238,6 +240,27 @@ class Net2(nn.Module): ...@@ -238,6 +240,27 @@ class Net2(nn.Module):
self.grid) self.grid)
return self.net(x) return self.net(x)
import numpy as np
from cumm import tensorview as tv
from spconv.core_cc.csrc.sparse.all import SpconvOps
import pickle
import torch
from spconv.pytorch.cppcore import torch_tensor_to_tv
def sort_bench(path="/home/yy/asd.pkl", num_iters=10):
    """Repeatedly run ``SpconvOps.sort_1d_by_key`` on a pickled tensor.

    Args:
        path: pickle file holding a tensor with at least two dimensions;
            row 0 is the data that gets sorted.
        num_iters: how many times the sort is invoked.

    The function returns nothing and does no timing itself; it is meant to
    be called from inside a timed region.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point this at dumps you created yourself.
    with open(path, "rb") as f:
        a_th = pickle.load(f)
    mask_argsort = torch.empty((1, a_th.shape[1]),
                               dtype=torch.int32,
                               device=a_th.device)
    a_tv = torch_tensor_to_tv(a_th)
    mask_argsort_tv = torch_tensor_to_tv(mask_argsort)
    for _ in range(num_iters):
        # Work on a fresh clone so every iteration starts from the same
        # input (presumably the op modifies its first argument in place).
        a_tv_1 = a_tv.clone()
        SpconvOps.sort_1d_by_key(a_tv_1[0], mask_argsort_tv[0])
def main(): def main():
import pickle import pickle
...@@ -252,45 +275,46 @@ def main(): ...@@ -252,45 +275,46 @@ def main():
print(voxels.shape) print(voxels.shape)
# voxels = voxels[:100] # voxels = voxels[:100]
# coors = coors[:100] # coors = coors[:100]
dtype = torch.float32 dtype = torch.float16
device = torch.device("cuda:0")
voxels_th = torch.from_numpy(voxels).cuda().to(dtype) voxels_th = torch.from_numpy(voxels).to(device).to(dtype)
coors_th = torch.from_numpy(coors).cuda().int() coors_th = torch.from_numpy(coors).to(device).int()
voxels_th.requires_grad = True voxels_th.requires_grad = True
algo = spconv.ConvAlgo.Native algo = spconv.ConvAlgo.MaskImplicitGemm
net = Net(spatial_shape, algo).cuda().eval().to(dtype) net = Net(spatial_shape, algo).to(device).eval().to(dtype).train()
print(coors_th.shape) print(coors_th.shape)
out = net(voxels_th, coors_th, 1) out = net(voxels_th, coors_th, 1)
print(out.spatial_shape) print(out.spatial_shape)
print(voxels.mean(), voxels.max(), voxels.min()) print(voxels.mean(), voxels.max(), voxels.min())
dout = np.random.uniform(-0.2, 0.2, dout = np.random.uniform(-0.2, 0.2,
out.features.shape).astype(np.float32) out.features.shape).astype(np.float32)
dout_t = torch.from_numpy(dout).cuda().to(dtype) dout_t = torch.from_numpy(dout).to(device).to(dtype)
print(out.spatial_shape, out.features.mean(), out.features.max(), out.features.min()) print(out.spatial_shape, out.features.mean(), out.features.max(), out.features.min())
# times = []
# with torch.no_grad():
# for i in range(20):
# print("------------")
# torch.cuda.synchronize()
# t = time.time()
# out_nograd = net(voxels_th, coors_th, 1)
# torch.cuda.synchronize()
# times.append(time.time() - t)
# print("spconv time", np.mean(times[10:]))
times = [] times = []
with torch.no_grad():
for i in range(20):
print("------------")
torch.cuda.synchronize()
t = time.time()
out_nograd = net(voxels_th, coors_th, 1)
torch.cuda.synchronize()
# sort_bench()
times.append(time.time() - t)
print("spconv time", np.mean(times[10:]))
# times = []
for i in range(1): # for i in range(10):
out = net(voxels_th, coors_th, 1) # out = net(voxels_th, coors_th, 1)
print("------------") # print("------------")
torch.cuda.synchronize() # torch.cuda.synchronize()
t = time.time() # t = time.time()
out.features.backward(dout_t) # out.features.backward(dout_t)
torch.cuda.synchronize() # torch.cuda.synchronize()
times.append(time.time() - t) # times.append(time.time() - t)
# # print((net.grid == -1).float().sum(), net.grid.numel()) # print((net.grid == -1).float().sum(), net.grid.numel())
# # print("spconv time", time.time() - t) # print("spconv time", time.time() - t)
# print("spconv bw time", np.mean(times[5:])) # print("spconv bw time", np.mean(times[5:]))
......
...@@ -19,12 +19,16 @@ from pathlib import Path ...@@ -19,12 +19,16 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from spconv.core import ConvAlgo
import spconv.pytorch as spconv import spconv.pytorch as spconv
from spconv.test_utils import TestCase, generate_sparse_data, params_grid from spconv.test_utils import TestCase, generate_sparse_data, params_grid
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO
# import sparseconvnet as scn # import sparseconvnet as scn
# we must disable tf32 to increase reference precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class SparseConv3dTestTorch(nn.Module): class SparseConv3dTestTorch(nn.Module):
def __init__(self, def __init__(self,
...@@ -37,8 +41,9 @@ class SparseConv3dTestTorch(nn.Module): ...@@ -37,8 +41,9 @@ class SparseConv3dTestTorch(nn.Module):
stride, stride,
padding, padding,
dilation, dilation,
algo=spconv.ConvAlgo.Native): algo=spconv.ConvAlgo.MaskSplitImplicitGemm):
super().__init__() super().__init__()
self.algo = algo
layers = [ layers = [
spconv.SparseConv3d(in_channels, spconv.SparseConv3d(in_channels,
out_channels, out_channels,
...@@ -347,6 +352,7 @@ def scatter_nd(indices, updates, shape): ...@@ -347,6 +352,7 @@ def scatter_nd(indices, updates, shape):
class TestSpConv(TestCase): class TestSpConv(TestCase):
def testSpConv3d(self): def testSpConv3d(self):
np.random.seed(484) np.random.seed(484)
torch.manual_seed(48848)
devices = ["cuda:0"] devices = ["cuda:0"]
shapes = [[19, 18, 17]] shapes = [[19, 18, 17]]
batchsizes = [1, 2] batchsizes = [1, 2]
...@@ -357,17 +363,23 @@ class TestSpConv(TestCase): ...@@ -357,17 +363,23 @@ class TestSpConv(TestCase):
strides = [1, 2, 3] strides = [1, 2, 3]
paddings = [0, 1, 2] paddings = [0, 1, 2]
dilations = [1, 2, 3] dilations = [1, 2, 3]
# strides = [1] algos = [ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, ConvAlgo.MaskSplitImplicitGemm]
# paddings = [0] algos = [ConvAlgo.MaskSplitImplicitGemm]
# dilations = [1]
for dev, shape, bs, IC, OC, k, s, p, d in params_grid( for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes, devices, shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations): strides, paddings, dilations, algos):
if all([s > 1, d > 1]): if all([s > 1, d > 1]):
continue # don't support this. continue # don't support this.
print(k, s, p, d)
device = torch.device(dev) device = torch.device(dev)
num_points = [1000] * bs num_points = [1000] * bs
dtype = torch.float32
net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d, algo=al).to(device).to(dtype)
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d).to(device).to(dtype)
sparse_dict = generate_sparse_data(shape, num_points, IC) sparse_dict = generate_sparse_data(shape, num_points, IC)
features = np.ascontiguousarray(sparse_dict["features"]).astype( features = np.ascontiguousarray(sparse_dict["features"]).astype(
...@@ -375,29 +387,31 @@ class TestSpConv(TestCase): ...@@ -375,29 +387,31 @@ class TestSpConv(TestCase):
indices = np.ascontiguousarray( indices = np.ascontiguousarray(
sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
features_dense = sparse_dict["features_dense"].astype(np.float32) features_dense = sparse_dict["features_dense"].astype(np.float32)
if FILTER_HWIO:
filters = np.random.uniform(0, 1, size=[k, k, k, IC,
OC]).astype(np.float32)
else:
filters = np.random.uniform(0, 1, size=[k, k, k, OC,
IC]).astype(np.float32)
dtype = torch.float16
indices_t = torch.from_numpy(indices).int().to(device) indices_t = torch.from_numpy(indices).int().to(device)
features_t = torch.from_numpy(features).to(device).to(dtype) features_t = torch.from_numpy(features).to(device).to(dtype)
features_t.requires_grad = True features_t.requires_grad = True
features_dense_t = torch.from_numpy(features_dense).to(device).to(dtype) features_dense_t = torch.from_numpy(features_dense).to(device).to(dtype)
features_dense_t.requires_grad = True features_dense_t.requires_grad = True
net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, if net.algo == ConvAlgo.Native:
d).to(device).to(dtype) if FILTER_HWIO:
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, filters = np.random.uniform(-1, 1, size=[k, k, k, IC,
d).to(device).to(dtype) OC]).astype(np.float32)
filters_t = torch.from_numpy(filters).to(device).to(dtype) else:
if FILTER_HWIO: filters = np.random.uniform(-1, 1, size=[k, k, k, OC,
net_ref.net[0].weight.data[:] = filters_t.permute(4, 3, 0, 1, IC]).astype(np.float32)
2).contiguous() filters_t = torch.from_numpy(filters).to(device).to(dtype)
if FILTER_HWIO:
net_ref.net[0].weight.data[:] = filters_t.permute(4, 3, 0, 1,
2).contiguous()
else:
net_ref.net[0].weight.data[:] = filters_t.permute(3, 4, 0, 1,
2).contiguous()
else: else:
net_ref.net[0].weight.data[:] = filters_t.permute(3, 4, 0, 1, filters = np.random.uniform(-1, 1, size=[OC, k, k, k, IC]).astype(np.float32)
2).contiguous() filters_t = torch.from_numpy(filters).to(device).to(dtype)
net_ref.net[0].weight.data[:] = filters_t.permute(0, 4, 1, 2,
3).contiguous()
net.net[0].weight.data[:] = filters_t net.net[0].weight.data[:] = filters_t
out_ref = net_ref(features_dense_t) out_ref = net_ref(features_dense_t)
out = net(features_t, indices_t, bs).dense() out = net(features_t, indices_t, bs).dense()
...@@ -420,11 +434,14 @@ class TestSpConv(TestCase): ...@@ -420,11 +434,14 @@ class TestSpConv(TestCase):
for layer, layer_ref in zip(net.net, net_ref.net): for layer, layer_ref in zip(net.net, net_ref.net):
dw = layer.weight.grad.detach().cpu().numpy() dw = layer.weight.grad.detach().cpu().numpy()
dw_ref = layer_ref.weight.grad.detach().cpu().numpy() dw_ref = layer_ref.weight.grad.detach().cpu().numpy()
if FILTER_HWIO: if net.algo == ConvAlgo.Native:
if FILTER_HWIO:
dw = dw.transpose(4, 3, 0, 1, 2) dw = dw.transpose(4, 3, 0, 1, 2)
else:
dw = dw.transpose(3, 4, 0, 1, 2)
else: else:
dw = dw.transpose(3, 4, 0, 1, 2) # OHWI -> OIHW
dw = dw.transpose(0, 4, 1, 2, 3)
self.assertAllClose(dw, dw_ref, atol=1e-4) self.assertAllClose(dw, dw_ref, atol=1e-4)
self.assertAllClose(din_np, din_sparse_np, atol=1e-4) self.assertAllClose(din_np, din_sparse_np, atol=1e-4)
...@@ -592,10 +609,10 @@ class TestSpConv(TestCase): ...@@ -592,10 +609,10 @@ class TestSpConv(TestCase):
strides = [1, 2, 3] strides = [1, 2, 3]
paddings = [0, 1] paddings = [0, 1]
dilations = [1, 2, 3] dilations = [1, 2, 3]
ksizes = [2] # ksizes = [2]
strides = [2] # strides = [2]
paddings = [0] # paddings = [0]
dilations = [1] # dilations = [1]
for dev, shape, bs, IC, OC, k, s, p, d in params_grid( for dev, shape, bs, IC, OC, k, s, p, d in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes, devices, shapes, batchsizes, in_channels, out_channels, ksizes,
...@@ -797,4 +814,4 @@ if __name__ == '__main__': ...@@ -797,4 +814,4 @@ if __name__ == '__main__':
# main(algo=spconv.ConvAlgo.SparseConvNet, dtype=torch.float32) # main(algo=spconv.ConvAlgo.SparseConvNet, dtype=torch.float32)
# TestCase().assertAllClose(out_my, out_ref) # TestCase().assertAllClose(out_my, out_ref)
# unittest.main() # unittest.main()
TestSpConv().testSpConv3d() TestSpConv().testSpMaxPool3d()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment