Commit 21bb00ae authored by Yan Yan's avatar Yan Yan
Browse files

still working on c++ only

parent 899008fa
<!--
Copyright 2022 Yan Yan
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
TODO
\ No newline at end of file
...@@ -175,7 +175,7 @@ if disable_jit is not None and disable_jit == "1": ...@@ -175,7 +175,7 @@ if disable_jit is not None and disable_jit == "1":
std = "c++14" std = "c++14"
else: else:
std = "c++17" std = "c++17"
if CUMM_CPU_ONLY_BUILD: if not CUMM_CPU_ONLY_BUILD:
gemmtuner = GemmTunerSimple(cu) gemmtuner = GemmTunerSimple(cu)
gemmtuner.namespace = "csrc.sparse.convops.gemmops" gemmtuner.namespace = "csrc.sparse.convops.gemmops"
convtuner = ConvTunerSimple(convcu) convtuner = ConvTunerSimple(convcu)
......
...@@ -62,8 +62,7 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -62,8 +62,7 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
CompileInfo(), CompileInfo(),
ExternalAllocator(), ExternalAllocator(),
ExternalSpconvMatmul(), ExternalSpconvMatmul(),
SimpleExternalSpconvMatmul(), SimpleExternalSpconvMatmul(), # for debug, won't be included in release
] ]
pccm.builder.build_pybind(cus, pccm.builder.build_pybind(cus,
PACKAGE_ROOT / "core_cc", PACKAGE_ROOT / "core_cc",
......
...@@ -64,7 +64,7 @@ SPCONV_DEBUG_CPP_ONLY = project_is_editable(PACKAGE_NAME) ...@@ -64,7 +64,7 @@ SPCONV_DEBUG_CPP_ONLY = project_is_editable(PACKAGE_NAME)
class AllocKeys: class AllocKeys:
Pair = "Pair" PairBwd = "PairBwd"
IndiceNumPerLoc = "IndiceNumPerLoc" IndiceNumPerLoc = "IndiceNumPerLoc"
PairMask = "PairMask" PairMask = "PairMask"
MaskArgSort = "MaskArgSort" MaskArgSort = "MaskArgSort"
...@@ -102,4 +102,6 @@ SPCONV_DEBUG_WEIGHT = False ...@@ -102,4 +102,6 @@ SPCONV_DEBUG_WEIGHT = False
SPCONV_CPP_INDICE_PAIRS = True SPCONV_CPP_INDICE_PAIRS = True
SPCONV_CPP_INDICE_PAIRS_IGEMM = True SPCONV_CPP_INDICE_PAIRS_IGEMM = True
SPCONV_CPP_GEMM = True SPCONV_CPP_GEMM = True
\ No newline at end of file
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
\ No newline at end of file
...@@ -240,6 +240,28 @@ class SpconvOps: ...@@ -240,6 +240,28 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def avgpool_implicit_gemm_forward(out: Tensor, inp: Tensor, inds: Tensor, count_out: Tensor, stream: int = 0) -> None:
"""
Args:
out:
inp:
inds:
count_out:
stream:
"""
...
@staticmethod
def avgpool_implicit_gemm_backward(dout: Tensor, dinp: Tensor, inds: Tensor, count_out: Tensor, stream: int = 0) -> None:
"""
Args:
dout:
dinp:
inds:
count_out:
stream:
"""
...
@staticmethod
def maxpool_forward_cpu(out: Tensor, inp: Tensor, out_inds: Tensor, in_inds: Tensor) -> None: def maxpool_forward_cpu(out: Tensor, inp: Tensor, out_inds: Tensor, in_inds: Tensor) -> None:
""" """
Args: Args:
...@@ -280,15 +302,6 @@ class SpconvOps: ...@@ -280,15 +302,6 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key(data: Tensor, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor: def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
""" """
Args: Args:
...@@ -348,6 +361,24 @@ class SpconvOps: ...@@ -348,6 +361,24 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def maximum_value_int(data: Tensor, value: int, stream_int: int) -> None:
"""
Args:
data:
value:
stream_int:
"""
...
@staticmethod
def sort_1d_by_key(data: Tensor, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
indices:
stream:
"""
...
@staticmethod
def calc_point2voxel_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: def calc_point2voxel_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]:
""" """
Args: Args:
...@@ -407,6 +438,18 @@ class SpconvOps: ...@@ -407,6 +438,18 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> Dict[str, Tensor]:
"""
Args:
workspace:
kv:
num_act_in:
num_act_out_bound:
subm:
use_int64_hash_k:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]: def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]:
""" """
Args: Args:
...@@ -428,7 +471,7 @@ class SpconvOps: ...@@ -428,7 +471,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> int: def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0, num_out_act_bound: int = -1, num_input_act_bound: int = -1) -> int:
""" """
Args: Args:
allocator: allocator:
...@@ -445,5 +488,6 @@ class SpconvOps: ...@@ -445,5 +488,6 @@ class SpconvOps:
transposed: transposed:
stream_int: stream_int:
num_out_act_bound: num_out_act_bound:
num_input_act_bound:
""" """
... ...
...@@ -2,29 +2,29 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty ...@@ -2,29 +2,29 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
from pccm.stubs import EnumValue, EnumClassValue from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor from cumm.tensorview import Tensor
class ExternalAllocator: class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, is_temp_memory: bool = False, stream: int = 0) -> Tensor: def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
""" """
Args: Args:
name: name:
shape: shape:
dtype: dtype:
device: device:
is_temp_memory:
stream: stream:
is_temp_memory:
""" """
... ...
def empty(self, name: str, shape: List[int], dtype: int, device: int, is_temp_memory: bool = False, stream: int = 0) -> Tensor: def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
""" """
Args: Args:
name: name:
shape: shape:
dtype: dtype:
device: device:
is_temp_memory:
stream: stream:
is_temp_memory:
""" """
... ...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, is_temp_memory: bool = False, stream: int = 0) -> Tensor: def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
""" """
Args: Args:
name: name:
...@@ -32,11 +32,11 @@ class ExternalAllocator: ...@@ -32,11 +32,11 @@ class ExternalAllocator:
value: value:
dtype: dtype:
device: device:
is_temp_memory:
stream: stream:
is_temp_memory:
""" """
... ...
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int, is_temp_memory: bool = False, stream: int = 0) -> Tensor: def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
""" """
Args: Args:
name: name:
...@@ -44,8 +44,8 @@ class ExternalAllocator: ...@@ -44,8 +44,8 @@ class ExternalAllocator:
value: value:
dtype: dtype:
device: device:
is_temp_memory:
stream: stream:
is_temp_memory:
""" """
... ...
def get_tensor_by_name(self, name: str) -> Tensor: def get_tensor_by_name(self, name: str) -> Tensor:
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib, GemmBasicHost from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib, GemmBasicHost
import cumm import cumm
from cumm.conv.bases import ConvOpType, NHWC from cumm.conv.bases import ConvOpType, NHWC
...@@ -27,6 +28,7 @@ from .maxpool import IndiceMaxPool, IndiceMaxPoolCPU ...@@ -27,6 +28,7 @@ from .maxpool import IndiceMaxPool, IndiceMaxPoolCPU
from .gather import GatherCPU from .gather import GatherCPU
from .alloc import ExternalAllocator, ThrustAllocator from .alloc import ExternalAllocator, ThrustAllocator
from spconv.constants import AllocKeys from spconv.constants import AllocKeys
import re
class CustomThrustLib(pccm.Class): class CustomThrustLib(pccm.Class):
def __init__(self): def __init__(self):
...@@ -70,6 +72,11 @@ class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -70,6 +72,11 @@ class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin):
code.arg("num_bytes", "size_t") code.arg("num_bytes", "size_t")
return code return code
def to_snake_case(name):
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
name = re.sub('__([A-Z])', r'_\1', name)
name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name)
return name.lower()
class SpconvOps(pccm.Class): class SpconvOps(pccm.Class):
def __init__(self): def __init__(self):
...@@ -102,11 +109,19 @@ class SpconvOps(pccm.Class): ...@@ -102,11 +109,19 @@ class SpconvOps(pccm.Class):
self.add_impl_only_param_class(cuda_funcs, f"ops{ndim}d", self.add_impl_only_param_class(cuda_funcs, f"ops{ndim}d",
indices, indices,
f"SpconvIndices{ndim}D") f"SpconvIndices{ndim}D")
defines: List[str] = []
# static constexpr in c++ < 17 may cause
# undefined symbol. use macro instead.
for name in dir(AllocKeys): for name in dir(AllocKeys):
if not name.startswith("__"): if not name.startswith("__"):
v = getattr(AllocKeys, name) v = getattr(AllocKeys, name)
self.add_static_const("k" + name, "auto", f"tv::make_const_string({pccm.literal(v)})") defines.append(f"#define SPCONV_ALLOC_{to_snake_case(name).upper()} {pccm.literal(v)}")
define_str = "\n".join(defines)
self.add_global_code(define_str)
# for name in dir(AllocKeys):
# if not name.startswith("__"):
# v = getattr(AllocKeys, name)
# self.add_static_const("k" + name, "auto", f"tv::make_const_string({pccm.literal(v)})")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
...@@ -613,6 +628,40 @@ class SpconvOps(pccm.Class): ...@@ -613,6 +628,40 @@ class SpconvOps(pccm.Class):
""") """)
return code return code
@pccm.pybind.mark
@pccm.cuda.static_function
def avgpool_implicit_gemm_forward(self):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("out", "tv::Tensor")
code.arg("inp", "tv::Tensor")
code.arg("inds", "tv::Tensor")
code.arg("count_out", "tv::Tensor")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.add_dependency(IndiceMaxPool)
code.raw(f"""
return IndiceMaxPool::forward_avgpool_implicit_gemm(out, inp, inds, count_out, stream);
""")
return code
@pccm.pybind.mark
@pccm.cuda.static_function
def avgpool_implicit_gemm_backward(self):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("dout", "tv::Tensor")
code.arg("dinp", "tv::Tensor")
code.arg("inds", "tv::Tensor")
code.arg("count_out", "tv::Tensor")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.add_dependency(IndiceMaxPool)
code.raw(f"""
return IndiceMaxPool::backward_avgpool_implicit_gemm(dout, dinp, inds, count_out, stream);
""")
return code
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
def maxpool_forward_cpu(self): def maxpool_forward_cpu(self):
...@@ -1035,6 +1084,97 @@ class SpconvOps(pccm.Class): ...@@ -1035,6 +1084,97 @@ class SpconvOps(pccm.Class):
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
# cpu only build can't use pccm.cuda
__CUDA_DECORATOR = pccm.static_function
if not CUMM_CPU_ONLY_BUILD:
__CUDA_DECORATOR = pccm.cuda.static_function
@pccm.pybind.mark
@__CUDA_DECORATOR
def maximum_value_int(self):
code = pccm.FunctionCode()
if not CUMM_CPU_ONLY_BUILD:
code.add_param_class("cudakers", self.cuda_common_kernel)
code.arg("data", "tv::Tensor")
code.arg("value", "int")
code.arg("stream_int", "std::uintptr_t")
code.raw(f"""
auto size = data.size();
using ints_t = std::tuple<int32_t, int16_t, int8_t, int64_t, uint32_t, uint64_t, uint16_t, uint8_t>;
""")
with code.block("", start="tv::Dispatch<ints_t>()(data.dtype(), [&](auto I){", end="});"):
code.raw(f"""
using T = TV_DECLTYPE(I);
auto ptr = data.data_ptr<T>();
""")
with code.if_("data.is_cpu()"):
code.raw(f"""
for (int i = 0; i < size; ++i){{
ptr[i] = std::max(ptr[i], T(value));
}}
""")
with code.else_():
if not CUMM_CPU_ONLY_BUILD:
code.raw(f"""
tv::cuda::Launch lanucher(size, reinterpret_cast<cudaStream_t>(stream_int));
lanucher(cudakers::maximum_value_kernel<T>, ptr, value, size);
""")
else:
code.raw(f"""
TV_THROW_RT_ERR("only support cpu.");
""")
return code
@pccm.pybind.mark
@pccm.cuda.static_function
def sort_1d_by_key(self):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("data", "tv::Tensor")
code.arg("indices",
"tv::Tensor",
"tv::Tensor()",
pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.code_after_include = f"""
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", self.cuda_common_kernel)
code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
thrust::stable_sort_by_key(thrust_ctx, ptr_tr, ptr_tr + data.dim(0), ptr_k, SmallOrEqualTo<uint32_t>());
}});
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
def calc_point2voxel_meta_data(self): def calc_point2voxel_meta_data(self):
...@@ -1249,16 +1389,55 @@ class SpconvOps(pccm.Class): ...@@ -1249,16 +1389,55 @@ class SpconvOps(pccm.Class):
code.arg("subm, use_int64_hash_k", "bool") code.arg("subm, use_int64_hash_k", "bool")
code.raw(f""" code.raw(f"""
if (subm){{ if (subm){{
return 2 * num_act_in * (use_int64_hash_k ? 2 : 3) * sizeof(int); return 2 * num_act_out_bound * (use_int64_hash_k ? 3 : 2) * sizeof(int);
}}else{{ }}else{{
size_t pair_single_size = kv * num_act_in; size_t pair_single_size = kv * num_act_in; // 40000
size_t ind_uniq_and_bkp_size = (pair_single_size + 1) * 2 * (use_int64_hash_k ? sizeof(int64_t) : sizeof(int32_t)); size_t ind_uniq_and_bkp_size = (pair_single_size + 1) * 2 * (use_int64_hash_k ? sizeof(int64_t) : sizeof(int32_t));
size_t hash_size = 2 * num_act_out_bound * (use_int64_hash_k ? 2 : 3) * sizeof(int); size_t hash_size = 2 * num_act_out_bound * (use_int64_hash_k ? 3 : 2) * sizeof(int);
return ind_uniq_and_bkp_size + hash_size; return ind_uniq_and_bkp_size + hash_size;
}} }}
""") """)
return code.ret("std::size_t") return code.ret("std::size_t")
@pccm.pybind.mark
@pccm.static_function
def get_indice_gen_tensors_from_workspace(self):
code = pccm.code()
code.arg("workspace", "uint8_t*")
code.arg("kv", "size_t")
code.arg("num_act_in", "size_t")
code.arg("num_act_out_bound", "size_t")
code.arg("subm, use_int64_hash_k", "bool")
code.raw(f"""
std::unordered_map<std::string, tv::Tensor> res;
auto ws_prev = workspace;
auto expected_size = get_indice_gen_workspace_size(kv, num_act_in, num_act_out_bound, subm, use_int64_hash_k);
if (use_int64_hash_k){{
auto ten = tv::from_blob(workspace, {{int64_t(num_act_out_bound) * 2}}, tv::int64, 0);
res.insert({{{pccm.literal(AllocKeys.HashKOrKV)}, ten}});
workspace += ten.nbytes();
auto ten2 = tv::from_blob(workspace, {{int64_t(num_act_out_bound) * 2}}, tv::int32, 0);
res.insert({{{pccm.literal(AllocKeys.HashV)}, ten2}});
workspace += ten2.nbytes();
}}else{{
auto ten = tv::from_blob(workspace, {{2, int64_t(num_act_out_bound) * 2}}, tv::int32, 0);
res.insert({{{pccm.literal(AllocKeys.HashKOrKV)}, ten}});
workspace += ten.nbytes();
}}
if (!subm){{
size_t pair_single_size = kv * int64_t(num_act_in);
auto ten = tv::from_blob(workspace, {{pair_single_size + 1}}, use_int64_hash_k ? tv::int64 : tv::int32, 0);
res.insert({{{pccm.literal(AllocKeys.IndicePairsUniq)}, ten}});
workspace += ten.nbytes();
auto ten2 = tv::from_blob(workspace, {{pair_single_size + 1}}, use_int64_hash_k ? tv::int64 : tv::int32, 0);
res.insert({{{pccm.literal(AllocKeys.IndicePairsUniqBackup)}, ten2}});
workspace += ten2.nbytes();
}}
TV_ASSERT_RT_ERR(workspace - ws_prev == expected_size, "this shouldn't happen");
return res;
""")
return code.ret("std::unordered_map<std::string, tv::Tensor>")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
def get_indice_pairs_implicit_gemm(self): def get_indice_pairs_implicit_gemm(self):
...@@ -1282,15 +1461,9 @@ class SpconvOps(pccm.Class): ...@@ -1282,15 +1461,9 @@ class SpconvOps(pccm.Class):
code.raw(f""" code.raw(f"""
auto tvctx = tv::Context(); auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int)); tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo); auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>()); int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32"); TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int64_t> input_dims_i64(input_dims.begin(), input_dims.end());
int64_t spatial_volume = std::accumulate(input_dims_i64.begin(),
input_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
std::vector<int> out_shape; std::vector<int> out_shape;
if (!subm){{ if (!subm){{
if (transposed){{ if (transposed){{
...@@ -1306,20 +1479,42 @@ class SpconvOps(pccm.Class): ...@@ -1306,20 +1479,42 @@ class SpconvOps(pccm.Class):
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims); TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
}} }}
}} }}
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm || TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm, "only support implicit gemm"); conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm, "only support implicit gemm");
bool is_mask_split = conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm; bool is_mask_split = conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm;
int mask_split_count = is_mask_split ? 2 : 1; int mask_split_count = is_mask_split ? 2 : 1;
tv::Tensor pair; tv::Tensor pair;
int64_t num_act_in = indices.dim(0);
if (subm){{ if (subm){{
pair = allocator.full_int({pccm.literal(AllocKeys.Pair)}, if (is_train){{
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device()); // query pair for fwd and bwd
pair = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{2, kv, num_act_in}}, -1, indices.dtype(), indices.device(), stream_int);
}}else{{
// query pair fwd only
pair = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{1, kv, num_act_in}}, -1, indices.dtype(), indices.device(), stream_int);
}}
}}else{{ }}else{{
pair = allocator.full_int({pccm.literal(AllocKeys.Pair)}, if (is_train){{
{{kv, indices.dim(0)}}, -1, indices.dtype(), indices.device()); // query pair bwd
pair = allocator.full_int({pccm.literal(AllocKeys.PairBwd)},
{{kv, num_act_in}}, -1, indices.dtype(), indices.device(), stream_int);
}}else{{
// don't need pair bwd, empty
pair = tv::Tensor();
}}
}} }}
auto indice_num_per_loc = allocator.zeros({pccm.literal(AllocKeys.IndiceNumPerLoc)}, auto indice_num_per_loc = allocator.zeros({pccm.literal(AllocKeys.IndiceNumPerLoc)},
{{kv}}, indices.dtype(), indices.device()); {{kv}}, indices.dtype(), indices.device(), stream_int);
tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1); tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1);
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>(); auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
...@@ -1360,20 +1555,23 @@ class SpconvOps(pccm.Class): ...@@ -1360,20 +1555,23 @@ class SpconvOps(pccm.Class):
}} }}
auto pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)}, auto pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0); {{mask_split_count, num_act_in}}, tv::uint32, 0, stream_int);
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc, generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, false, stream_int); batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)}, auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, out_inds.dim(0)}}, tv::int32, 0); {{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{ for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int); sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
}} }}
}}else{{ }}else{{
auto pair_bwd = pair; auto pair_bwd = pair;
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, auto pair_size = kv * num_act_in;
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair_size + 1)}},
indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniq)}); indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniq)});
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair_size + 1)}},
indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniqBackup)}); indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniqBackup)});
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor; auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
...@@ -1386,21 +1584,26 @@ class SpconvOps(pccm.Class): ...@@ -1386,21 +1584,26 @@ class SpconvOps(pccm.Class):
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{ if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound; num_act_out = num_out_act_bound;
}} }}
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out); indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
// for fixed size allocator, all memory alloc size must be fixed.
out_inds = allocator.empty({pccm.literal(AllocKeys.OutIndices)}, out_inds = allocator.empty({pccm.literal(AllocKeys.OutIndices)},
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0); {{num_act_out, indices.dim(1)}}, indices.dtype(), 0, stream_int);
auto pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)}, auto pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device()); {{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
auto pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)}, auto pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_out}}, tv::uint32, 0); {{mask_split_count, num_act_out}}, tv::uint32, 0, stream_int);
auto pair_mask_bwd = tv::Tensor(); auto pair_mask_bwd = tv::Tensor();
if (is_train){{ if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)}, pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0); {{mask_split_count, indices.dim(0)}}, tv::uint32, 0, stream_int);
}} }}
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad; ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v; tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{ if (use_int64_hash_k){{
// temp memory don't need to be fixed, static alloc will check
// that tensor is large enough.
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, hash_k_guard = allocator.empty_guard({{num_act_out * 2}},
tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)}); tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, hash_v_gurad = allocator.empty_guard({{num_act_out * 2}},
...@@ -1419,12 +1622,13 @@ class SpconvOps(pccm.Class): ...@@ -1419,12 +1622,13 @@ class SpconvOps(pccm.Class):
batch_size, out_shape, input_dims, ksize, stride, padding, dilation, batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int); transposed, stream_int);
auto mask_argsort_fwd = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)}, auto mask_argsort_fwd = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, out_inds.dim(0)}}, tv::int32, 0); {{mask_split_count, num_act_out}}, tv::int32, 0, stream_int);
tv::Tensor mask_argsort_bwd = tv::Tensor(); tv::Tensor mask_argsort_bwd = tv::Tensor();
if (is_train){{ if (is_train){{
mask_argsort_bwd = allocator.zeros({pccm.literal(AllocKeys.MaskArgSortBwd)}, mask_argsort_bwd = allocator.zeros({pccm.literal(AllocKeys.MaskArgSortBwd)},
{{mask_split_count, indices.dim(0)}}, tv::int32, 0); {{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
}} }}
if (is_mask_split){{ if (is_mask_split){{
for (int j = 0; j < mask_split_count; ++j){{ for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1); auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
...@@ -1449,6 +1653,7 @@ class SpconvOps(pccm.Class): ...@@ -1449,6 +1653,7 @@ class SpconvOps(pccm.Class):
mask_argsort_bwd[0], stream_int); mask_argsort_bwd[0], stream_int);
}} }}
}} }}
}} }}
return std::make_tuple(mask_tensor, num_act_out); return std::make_tuple(mask_tensor, num_act_out);
""") """)
...@@ -1467,18 +1672,18 @@ class SpconvOps(pccm.Class): ...@@ -1467,18 +1672,18 @@ class SpconvOps(pccm.Class):
code.arg("subm, transposed", f"bool") code.arg("subm, transposed", f"bool")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("num_out_act_bound", f"int", "-1") code.arg("num_out_act_bound", f"int", "-1")
code.arg("num_input_act_bound", f"int", "-1")
code.raw(f""" code.raw(f"""
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>()); int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo); auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kNative, "only support kNative"); TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kNative, "only support kNative");
if (num_out_act_bound > 0){{
TV_ASSERT_RT_ERR(num_input_act_bound > 0 && indices.dim(0) <= num_input_act_bound,
"out bound and input bound must both larger than zero");
}}
std::vector<int64_t> input_dims_i64(input_dims.begin(), input_dims.end());
int64_t spatial_volume = std::accumulate(input_dims_i64.begin(),
input_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
std::vector<int> out_shape; std::vector<int> out_shape;
if (!subm){{ if (!subm){{
if (transposed){{ if (transposed){{
...@@ -1494,12 +1699,24 @@ class SpconvOps(pccm.Class): ...@@ -1494,12 +1699,24 @@ class SpconvOps(pccm.Class):
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims); TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
}} }}
}} }}
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
tv::Tensor pair; tv::Tensor pair;
pair = allocator.full_int({pccm.literal(AllocKeys.Pair)}, int64_t num_act_in_bounded = indices.dim(0);
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
if (num_out_act_bound > 0){{
// we need stable pair stride for bounded output
num_act_in_bounded = num_input_act_bound;
}}
pair = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{2, kv, num_act_in_bounded}}, -1, indices.dtype(), indices.device(), stream_int);
auto indice_num_per_loc = allocator.zeros({pccm.literal(AllocKeys.IndiceNumPerLoc)}, auto indice_num_per_loc = allocator.zeros({pccm.literal(AllocKeys.IndiceNumPerLoc)},
{{kv}}, indices.dtype(), indices.device()); {{kv}}, indices.dtype(), indices.device(), stream_int);
tv::Tensor out_inds; tv::Tensor out_inds;
int num_act_out = -1; int num_act_out = -1;
""") """)
...@@ -1576,13 +1793,18 @@ class SpconvOps(pccm.Class): ...@@ -1576,13 +1793,18 @@ class SpconvOps(pccm.Class):
// TODO pytorch unique may be faster? // TODO pytorch unique may be faster?
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int); num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
bool use_bound_algo = false; bool use_bound_algo = false;
int64_t num_out_bounded = num_act_out;
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{ if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound; num_act_out = num_out_act_bound;
use_bound_algo = true; use_bound_algo = true;
}} }}
if (num_out_act_bound > 0 ){{
num_out_bounded = num_out_act_bound;
}}
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out); indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty({pccm.literal(AllocKeys.OutIndices)}, out_inds = allocator.empty({pccm.literal(AllocKeys.OutIndices)},
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0); {{num_out_bounded, indices.dim(1)}}, indices.dtype(), 0, stream_int);
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad; ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v; tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{ if (use_int64_hash_k){{
......
...@@ -2,7 +2,8 @@ import pccm ...@@ -2,7 +2,8 @@ import pccm
from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib
from spconv.constants import AllocKeys from spconv.constants import AllocKeys
from cumm.constants import CUMM_CPU_ONLY_BUILD
from .indices import CudaCommonKernel
class ExternalAllocatorGuard(pccm.Class): class ExternalAllocatorGuard(pccm.Class):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -53,8 +54,8 @@ class ExternalAllocator(pccm.Class): ...@@ -53,8 +54,8 @@ class ExternalAllocator(pccm.Class):
code.arg("shape", "std::vector<int64_t>") code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
...@@ -66,8 +67,8 @@ class ExternalAllocator(pccm.Class): ...@@ -66,8 +67,8 @@ class ExternalAllocator(pccm.Class):
code.arg("shape", "std::vector<int64_t>") code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
...@@ -80,8 +81,8 @@ class ExternalAllocator(pccm.Class): ...@@ -80,8 +81,8 @@ class ExternalAllocator(pccm.Class):
code.arg("value", "int") code.arg("value", "int")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
...@@ -94,8 +95,9 @@ class ExternalAllocator(pccm.Class): ...@@ -94,8 +95,9 @@ class ExternalAllocator(pccm.Class):
code.arg("value", "float") code.arg("value", "float")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True) @pccm.pybind.mark(virtual=True)
...@@ -129,7 +131,7 @@ class ExternalAllocator(pccm.Class): ...@@ -129,7 +131,7 @@ class ExternalAllocator(pccm.Class):
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
// "" means temp memory // "" means temp memory
auto ten = zeros(name, shape, dtype, device, true, stream); auto ten = zeros(name, shape, dtype, device, stream, true);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{ return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten); this->free(ten);
}}); }});
...@@ -145,7 +147,7 @@ class ExternalAllocator(pccm.Class): ...@@ -145,7 +147,7 @@ class ExternalAllocator(pccm.Class):
code.arg("name", "std::string", "\"\"") code.arg("name", "std::string", "\"\"")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
auto ten = empty(name, shape, dtype, device, true, stream); auto ten = empty(name, shape, dtype, device, stream, true);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{ return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten); this->free(ten);
}}); }});
...@@ -162,7 +164,7 @@ class ExternalAllocator(pccm.Class): ...@@ -162,7 +164,7 @@ class ExternalAllocator(pccm.Class):
code.arg("name", "std::string", "\"\"") code.arg("name", "std::string", "\"\"")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
auto ten = full_int(name, shape, value, dtype, device, true, stream); auto ten = full_int(name, shape, value, dtype, device, stream, true);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{ return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten); this->free(ten);
}}); }});
...@@ -179,7 +181,7 @@ class ExternalAllocator(pccm.Class): ...@@ -179,7 +181,7 @@ class ExternalAllocator(pccm.Class):
code.arg("name", "std::string", "\"\"") code.arg("name", "std::string", "\"\"")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
auto ten = full_float(name, shape, value, dtype, device, true, stream); auto ten = full_float(name, shape, value, dtype, device, stream, true);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor t){{ return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor t){{
this->free(t); this->free(t);
}}); }});
...@@ -222,8 +224,10 @@ class ThrustAllocator(pccm.Class): ...@@ -222,8 +224,10 @@ class ThrustAllocator(pccm.Class):
""") """)
return code return code
class StaticAllocator(ExternalAllocator): class StaticAllocator(ExternalAllocator):
"""a simple allocator for tensorrt plugin. """a static allocator for tensorrt plugin.
""" """
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -232,6 +236,7 @@ class StaticAllocator(ExternalAllocator): ...@@ -232,6 +236,7 @@ class StaticAllocator(ExternalAllocator):
self.add_member("repr_", "std::string") self.add_member("repr_", "std::string")
self.add_member("thrust_tmp_tensor_", "tv::Tensor") self.add_member("thrust_tmp_tensor_", "tv::Tensor")
self.grow = 1.5 self.grow = 1.5
self.cuda_common_kernel = CudaCommonKernel()
@pccm.pybind.mark @pccm.pybind.mark
@pccm.constructor @pccm.constructor
...@@ -242,7 +247,22 @@ class StaticAllocator(ExternalAllocator): ...@@ -242,7 +247,22 @@ class StaticAllocator(ExternalAllocator):
code.raw(f""" code.raw(f"""
std::stringstream ss; std::stringstream ss;
for (auto& p : tensor_dict){{ for (auto& p : tensor_dict){{
tv::ssprint(ss, p.first, p.second.shape(), tv::dtype_str(p.second.dtype()), "\\n"); tv::sstream_print(ss, p.first, p.second.shape(), tv::dtype_str(p.second.dtype()), "\\n");
}}
repr_ = ss.str();
""")
return code
@pccm.pybind.mark
@pccm.member_function
def set_new_tensor_dict(self):
code = pccm.code()
code.arg("tensor_dict", "std::unordered_map<std::string, tv::Tensor>")
code.raw(f"""
tensor_dict_ = tensor_dict;
std::stringstream ss;
for (auto& p : tensor_dict){{
tv::sstream_print(ss, p.first, p.second.shape(), tv::dtype_str(p.second.dtype()), "\\n");
}} }}
repr_ = ss.str(); repr_ = ss.str();
""") """)
...@@ -255,12 +275,21 @@ class StaticAllocator(ExternalAllocator): ...@@ -255,12 +275,21 @@ class StaticAllocator(ExternalAllocator):
code.arg("shape", "std::vector<int64_t>") code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.raw(f""" code.raw(f"""
auto res = get_tensor_by_name(name); auto res = get_tensor_by_name(name);
size_t total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()); size_t total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
TV_ASSERT_RT_ERR(res.nbytes() >= total * tv::bit_size(tv::DType(dtype)) TV_ASSERT_RT_ERR(res.nbytes() >= total * tv::bit_size(tv::DType(dtype)) / 8
&& res.device() == device, "alloc failed", shape, res.shape()); && res.device() == device, "alloc failed, tensor size too small", shape, res.shape());
return tv::from_blob(res.raw_data(), shape, dtype, device);
// if (is_temp_memory){{
// }}else{{
// // size must exactly match
// TV_ASSERT_RT_ERR(res.nbytes() == total * tv::bit_size(tv::DType(dtype)) / 8
// && res.device() == device, "alloc failed, named memory size must match", shape, res.shape());
// }}
return tv::from_blob(res.raw_data(), shape, tv::DType(dtype), device);
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
...@@ -273,16 +302,22 @@ class StaticAllocator(ExternalAllocator): ...@@ -273,16 +302,22 @@ class StaticAllocator(ExternalAllocator):
code.arg("shape", "std::vector<int64_t>") code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
code.raw(f""" code.raw(f"""
auto tvctx = tv::Context(); auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream)); """)
auto blob = _get_raw_and_check(name, shape, dtype, device); if not CUMM_CPU_ONLY_BUILD:
code.raw(f"""
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream));
""")
code.raw(f"""
auto blob = _get_raw_and_check(name, shape, dtype, device, is_temp_memory);
return blob.zero_(tvctx); return blob.zero_(tvctx);
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.member_function(virtual=True) @pccm.member_function(virtual=True)
def empty(self): def empty(self):
...@@ -291,8 +326,8 @@ class StaticAllocator(ExternalAllocator): ...@@ -291,8 +326,8 @@ class StaticAllocator(ExternalAllocator):
code.arg("shape", "std::vector<int64_t>") code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
code.raw(f""" code.raw(f"""
if (name == {pccm.literal(AllocKeys.ThrustTemp)}){{ if (name == {pccm.literal(AllocKeys.ThrustTemp)}){{
// thrust tmp shouldn't inside tensor_dict. use a simple method to allocate // thrust tmp shouldn't inside tensor_dict. use a simple method to allocate
...@@ -300,23 +335,28 @@ class StaticAllocator(ExternalAllocator): ...@@ -300,23 +335,28 @@ class StaticAllocator(ExternalAllocator):
// so we can just use one tensor // so we can just use one tensor
tv::Tensor res = thrust_tmp_tensor_; tv::Tensor res = thrust_tmp_tensor_;
if (res.empty()){{ if (res.empty()){{
res = tv::empty(shape, dtype, device); res = tv::empty(shape, tv::DType(dtype), device);
thrust_tmp_tensor_ = res; thrust_tmp_tensor_ = res;
}} }}
if (shape[0] > thrust_tmp_tensor_.dim(0)){{ if (shape[0] > thrust_tmp_tensor_.dim(0)){{
res = tv::empty({{int64_t(shape[0] * {self.grow})}}, dtype, device); res = tv::empty({{int64_t(shape[0] * {self.grow})}}, tv::DType(dtype), device);
thrust_tmp_tensor_ = res; thrust_tmp_tensor_ = res;
}} }}
return res; return res;
}}else{{ }}else{{
auto blob = _get_raw_and_check(name, shape, dtype, device); auto blob = _get_raw_and_check(name, shape, dtype, device, is_temp_memory);
return blob; return blob;
}} }}
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
# cpu only build can't use pccm.cuda
__CUDA_DECORATOR = pccm.member_function
if not CUMM_CPU_ONLY_BUILD:
__CUDA_DECORATOR = pccm.cuda.member_function
@pccm.pybind.mark @pccm.pybind.mark
@pccm.member_function(virtual=True) @__CUDA_DECORATOR
def full_int(self): def full_int(self):
code = pccm.code() code = pccm.code()
code.arg("name", "std::string") code.arg("name", "std::string")
...@@ -324,17 +364,36 @@ class StaticAllocator(ExternalAllocator): ...@@ -324,17 +364,36 @@ class StaticAllocator(ExternalAllocator):
code.arg("value", "int") code.arg("value", "int")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
code.raw(f""" code.raw(f"""
auto tvctx = tv::Context(); auto tvctx = tv::Context();
auto blob = _get_raw_and_check(name, shape, dtype, device); auto blob = _get_raw_and_check(name, shape, dtype, device, is_temp_memory);
return blob.fill_(tvctx, value);
""")
if not CUMM_CPU_ONLY_BUILD:
code.add_param_class("cudakers", self.cuda_common_kernel)
code.raw(f"""
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream));
using ints_t = std::tuple<int32_t, int16_t, int8_t, int64_t, uint32_t, uint64_t, uint16_t, uint8_t>;
tv::Dispatch<ints_t>()(blob.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
tv::cuda::Launch lanucher_fill(blob.size(), reinterpret_cast<cudaStream_t>(stream));
lanucher_fill(cudakers::fill_kernel<T>, blob.data_ptr<T>(), value, blob.size());
}});
""")
else:
code.raw(f"""
blob.fill_(value);
""")
code.raw(f"""
return blob;
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.member_function(virtual=True) @__CUDA_DECORATOR
def full_float(self): def full_float(self):
code = pccm.code() code = pccm.code()
code.arg("name", "std::string") code.arg("name", "std::string")
...@@ -342,11 +401,29 @@ class StaticAllocator(ExternalAllocator): ...@@ -342,11 +401,29 @@ class StaticAllocator(ExternalAllocator):
code.arg("value", "float") code.arg("value", "float")
code.arg("dtype", "int") code.arg("dtype", "int")
code.arg("device", "int") code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
code.raw(f"""
auto tvctx = tv::Context();
auto blob = _get_raw_and_check(name, shape, dtype, device, is_temp_memory);
""")
if not CUMM_CPU_ONLY_BUILD:
code.add_param_class("cudakers", self.cuda_common_kernel)
code.raw(f"""
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream));
using dtypes_t = std::tuple<float, double>;
tv::Dispatch<dtypes_t>()(blob.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
tv::cuda::Launch lanucher_fill(blob.size(), reinterpret_cast<cudaStream_t>(stream));
lanucher_fill(cudakers::fill_kernel<T>, blob.data_ptr<T>(), value, blob.size());
}});
""")
else:
code.raw(f"""
blob.fill_(value);
""")
code.raw(f""" code.raw(f"""
auto blob = _get_raw_and_check(name, shape, dtype, device); return blob;
return blob.fill_(tvctx, value);
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
...@@ -364,6 +441,7 @@ class StaticAllocator(ExternalAllocator): ...@@ -364,6 +441,7 @@ class StaticAllocator(ExternalAllocator):
@pccm.pybind.mark @pccm.pybind.mark
@pccm.member_function(virtual=True) @pccm.member_function(virtual=True)
def free(self): def free(self):
# nothing here because this is a static allocator
code = pccm.code() code = pccm.code()
code.arg("ten", "tv::Tensor") code.arg("ten", "tv::Tensor")
return code return code
......
...@@ -78,11 +78,9 @@ class ExternalSpconvMatmul(pccm.Class): ...@@ -78,11 +78,9 @@ class ExternalSpconvMatmul(pccm.Class):
return code return code
class SimpleExternalSpconvMatmul(ExternalSpconvMatmul): class SimpleExternalSpconvMatmul(ExternalSpconvMatmul):
"""a helper class to warp matmul operations """implement gemm in cuda via cublasLt. (only support forward)
because we don't want to implement matmul should be used with tensorrt plugin.
(link to cublas/mkl/pytorch) in python package.
""" """
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_dependency(TensorView, ExternalAllocator) self.add_dependency(TensorView, ExternalAllocator)
...@@ -311,7 +309,7 @@ class SimpleExternalSpconvMatmul(ExternalSpconvMatmul): ...@@ -311,7 +309,7 @@ class SimpleExternalSpconvMatmul(ExternalSpconvMatmul):
TV_THROW_RT_ERR("unsupported"); TV_THROW_RT_ERR("unsupported");
}} }}
check_cublas_status(cublasLtMatmul( check_cublas_status(cublasLtMatmul(
handle, operationDesc, alpha_storage, a.raw_data(), Adesc, b.raw_data(), handle, operationDesc, alpha_storage, a.const_raw_data(), Adesc, b.const_raw_data(),
Bdesc, beta_storage, c.raw_data(), Cdesc, c.raw_data(), Cdesc, Bdesc, beta_storage, c.raw_data(), Cdesc, c.raw_data(), Cdesc,
&heuristicResult.algo, nullptr, 0, stream)); &heuristicResult.algo, nullptr, 0, stream));
if (preference) if (preference)
...@@ -1417,11 +1415,12 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1417,11 +1415,12 @@ class ConvGemmOps(pccm.ParameterizedClass):
is_KC_not_CK, kv_center, out_channel); is_KC_not_CK, kv_center, out_channel);
}}else{{ }}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device()); {{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
}} }}
if (kv == 1 && subm){{ if (kv == 1 && subm){{
return; return;
}} }}
auto indice_pair_num_cpu = indice_pair_num.cpu(); auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>(); auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0; int maxnhot = 0;
...@@ -1618,7 +1617,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1618,7 +1617,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
int kv_center = kv / 2; int kv_center = kv / 2;
tv::Tensor din; tv::Tensor din;
auto dfilters = allocator.zeros({pccm.literal(AllocKeys.DFilters)}, auto dfilters = allocator.zeros({pccm.literal(AllocKeys.DFilters)},
prev_filter_shape_vec, features.dtype(), features.device()); prev_filter_shape_vec, features.dtype(), features.device(), stream_int);
dfilters = dfilters.view(filters.shape()); dfilters = dfilters.view(filters.shape());
if (subm){{ if (subm){{
din = ext_mm.indice_conv_bwd_init_gemm({pccm.literal(AllocKeys.Features)}, din = ext_mm.indice_conv_bwd_init_gemm({pccm.literal(AllocKeys.Features)},
...@@ -1628,7 +1627,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1628,7 +1627,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
is_KC_not_CK, kv_center); is_KC_not_CK, kv_center);
}}else{{ }}else{{
din = allocator.zeros({pccm.literal(AllocKeys.DIn)}, din = allocator.zeros({pccm.literal(AllocKeys.DIn)},
features.shape_vector(), features.dtype(), features.device()); features.shape_vector(), features.dtype(), features.device(), stream_int);
}} }}
if (kv == 1 && subm){{ if (kv == 1 && subm){{
return; return;
...@@ -1922,10 +1921,10 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1922,10 +1921,10 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor out_features; tv::Tensor out_features;
if (is_subm){{ if (is_subm){{
out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device()); {{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
}}else{{ }}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device()); {{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
}} }}
auto arch = get_compute_capability(); auto arch = get_compute_capability();
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward); constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
...@@ -1966,7 +1965,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1966,7 +1965,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
if (is_train){{ if (is_train){{
mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)}, mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)},
{{num_split, tv::div_up(num_activate_out, mask_width)}}, {{num_split, tv::div_up(num_activate_out, mask_width)}},
tv::uint32, features.device()); tv::uint32, features.device(), stream_int);
for (int i = 0; i < num_split; ++i){{ for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]); mask_output_fwd_splits.push_back(mask_output_fwd[i]);
}} }}
...@@ -2042,13 +2041,13 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2042,13 +2041,13 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor din; tv::Tensor din;
if (is_subm){{ if (is_subm){{
din = allocator.empty({pccm.literal(AllocKeys.DIn)}, din = allocator.empty({pccm.literal(AllocKeys.DIn)},
features.shape_vector(), features.dtype(), features.device()); features.shape_vector(), features.dtype(), features.device(), stream_int);
}}else{{ }}else{{
din = allocator.zeros({pccm.literal(AllocKeys.DIn)}, din = allocator.zeros({pccm.literal(AllocKeys.DIn)},
features.shape_vector(), features.dtype(), features.device()); features.shape_vector(), features.dtype(), features.device(), stream_int);
}} }}
tv::Tensor dfilters = allocator.zeros({pccm.literal(AllocKeys.DFilters)}, tv::Tensor dfilters = allocator.zeros({pccm.literal(AllocKeys.DFilters)},
filters_shape_vec, filters.dtype(), filters.device()); filters_shape_vec, filters.dtype(), filters.device(), stream_int);
dfilters = dfilters.view(out_channel, -1, in_channel); dfilters = dfilters.view(out_channel, -1, in_channel);
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward); constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
......
...@@ -27,6 +27,11 @@ import numpy as np ...@@ -27,6 +27,11 @@ import numpy as np
class CudaCommonKernel(pccm.ParameterizedClass): class CudaCommonKernel(pccm.ParameterizedClass):
# we need to use PClass instead of Class # we need to use PClass instead of Class
# because cuda global function can't be put in class body. # because cuda global function can't be put in class body.
def __init__(self) -> None:
super().__init__()
self.add_include("tensorview/cuda/launch.h")
self.add_include("tensorview/cuda/kernel_utils.h")
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def arange_kernel(self): def arange_kernel(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -54,6 +59,19 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -54,6 +59,19 @@ class CudaCommonKernel(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.cuda_global_function
def maximum_value_kernel(self):
code = pccm.FunctionCode()
code.targ("T")
code.arg("data", f"T*")
code.arg("val", f"T")
code.arg("size", f"int")
code.raw(f"""
for (int i : tv::KernelLoopX<int>(size)) {{
data[i] = max(data[i], val);
}}
""")
return code
class ConvOutLocIter(pccm.ParameterizedClass): class ConvOutLocIter(pccm.ParameterizedClass):
def __init__(self, problem: ConvProblem): def __init__(self, problem: ConvProblem):
...@@ -260,6 +278,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -260,6 +278,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
assert dtype_indices == dtypes.int32 or dtype_indices == dtypes.int64 assert dtype_indices == dtypes.int32 or dtype_indices == dtypes.int64
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage1(self): def calc_conv_indices_stage1(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -282,7 +301,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -282,7 +301,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
loc_iter.set_filter_offset(filter_offset); loc_iter.set_filter_offset(filter_offset);
int indices_pair_size_mul_RS = indices_pair_size * RS; // int indices_pair_size_mul_RS = indices_pair_size * RS;
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size; int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{ for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
tv::array<int, {self.ndim + 1}> npq_offset; tv::array<int, {self.ndim + 1}> npq_offset;
...@@ -418,7 +437,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -418,7 +437,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
loc_iter.set_filter_offset(filter_offset); loc_iter.set_filter_offset(filter_offset);
int indices_pair_size_mul_RS = num_indices_in * RS; // int indices_pair_size_mul_RS = num_indices_in * RS;
int filter_offset_mul_indices_pair_size = filter_offset * num_indices_in; int filter_offset_mul_indices_pair_size = filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{ for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
tv::array<int, {self.ndim + 1}> npq_offset; tv::array<int, {self.ndim + 1}> npq_offset;
...@@ -479,7 +498,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -479,7 +498,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
atomicOr(mask_fwd + output_index, filter_mask_fwd); atomicOr(mask_fwd + output_index, filter_mask_fwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd); // atomicOr(mask_bwd + input_index, filter_mask_bwd);
indice_pairs_fwd_filter[output_index] = input_index; indice_pairs_fwd_filter[output_index] = input_index;
indice_pairs_bwd_filter[input_index] = output_index; if (indice_pairs_bwd != nullptr){{
indice_pairs_bwd_filter[input_index] = output_index;
}}
}} }}
}} }}
}} }}
...@@ -530,7 +551,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -530,7 +551,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
uint32_t filter_mask_fwd = (1u << (filter_offset)); uint32_t filter_mask_fwd = (1u << (filter_offset));
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out; auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in; // auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * num_indices_in; auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{ for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
auto output_coord_offset = indice_pairs_uniq_before_sort_filter[input_index]; auto output_coord_offset = indice_pairs_uniq_before_sort_filter[input_index];
...@@ -642,6 +663,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -642,6 +663,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("indices_pair_size", "int") code.arg("indices_pair_size", "int")
code.arg("RS", "int") code.arg("RS", "int")
code.arg("is_train", "bool")
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
uint32_t filter_mask_out = (1u << (filter_offset)); uint32_t filter_mask_out = (1u << (filter_offset));
...@@ -657,7 +680,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -657,7 +680,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
for (int i : tv::KernelLoopX<int>(num_indices)) {{ for (int i : tv::KernelLoopX<int>(num_indices)) {{
// atomicOr(mask + i, filter_mask_center); // atomicOr(mask + i, filter_mask_center);
indice_pairs[filter_offset_mul_indices_pair_size + i] = i; indice_pairs[filter_offset_mul_indices_pair_size + i] = i;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + i] = i; if (is_train){{
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + i] = i;
}}
}} }}
}} else {{ }} else {{
for (int output_index : tv::KernelLoopX<int>(num_indices)) {{ for (int output_index : tv::KernelLoopX<int>(num_indices)) {{
...@@ -674,10 +699,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -674,10 +699,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
atomicOr(mask + input_index, filter_mask_in); atomicOr(mask + input_index, filter_mask_in);
// for this output, we set correct input idx. // for this output, we set correct input idx.
indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index; indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + input_index] = output_index; if (is_train){{
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + input_index] = output_index;
}}
// the output in "input location" connect this output idx in another location. // the output in "input location" connect this output idx in another location.
indice_pairs[filter_offset_mul_indices_pair_size_1 + input_index] = output_index; indice_pairs[filter_offset_mul_indices_pair_size_1 + input_index] = output_index;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + output_index] = input_index; if (is_train){{
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + output_index] = input_index;
}}
}} }}
}} }}
}} }}
...@@ -702,6 +731,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -702,6 +731,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("indices_pair_size", "int") code.arg("indices_pair_size", "int")
code.arg("RS", "int") code.arg("RS", "int")
code.arg("is_train", "bool")
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
uint32_t filter_mask_out = (1u << (filter_offset)); uint32_t filter_mask_out = (1u << (filter_offset));
...@@ -715,7 +746,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -715,7 +746,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
if (filter_offset == (RS / 2)){{ if (filter_offset == (RS / 2)){{
for (int i : tv::KernelLoopX<int>(num_indices)) {{ for (int i : tv::KernelLoopX<int>(num_indices)) {{
indice_pairs[filter_offset_mul_indices_pair_size + i] = i; indice_pairs[filter_offset_mul_indices_pair_size + i] = i;
indice_ptr_inv[filter_offset_mul_indices_pair_size + i] = i; if (is_train){{
indice_ptr_inv[filter_offset_mul_indices_pair_size + i] = i;
}}
}} }}
}} else {{ }} else {{
for (int output_index : tv::KernelLoopX<int>(num_indices)) {{ for (int output_index : tv::KernelLoopX<int>(num_indices)) {{
...@@ -733,8 +766,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -733,8 +766,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index; indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index;
// the output in "input location" connect this output idx in another location. // the output in "input location" connect this output idx in another location.
indice_pairs[filter_offset_mul_indices_pair_size_1 + input_index] = output_index; indice_pairs[filter_offset_mul_indices_pair_size_1 + input_index] = output_index;
indice_ptr_inv[filter_offset_mul_indices_pair_size + input_index] = output_index; if (is_train){{
indice_ptr_inv[filter_offset_mul_indices_pair_size_1 + output_index] = input_index; indice_ptr_inv[filter_offset_mul_indices_pair_size + input_index] = output_index;
indice_ptr_inv[filter_offset_mul_indices_pair_size_1 + output_index] = input_index;
}}
}} }}
}} }}
}} }}
...@@ -760,15 +795,20 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -760,15 +795,20 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1] // indice_pairs: [2, kv, num_act_in]
tv::check_shape(indice_pairs, {{2, kv, indices.dim(0)}}); // indice_pairs_uniq: [num_act_in * kv + 1]
tv::check_shape(indice_pairs, {{2, kv, -1}});
// TV_ASSERT_RT_ERR(indice_pairs.dim(-1) == indices.dim(0), "error");
tv::check_shape(indice_num_per_loc, {{kv}}); tv::check_shape(indice_num_per_loc, {{kv}});
int64_t uniq_size = indice_pairs.size() / 2 + 1; int64_t uniq_size = indice_pairs.size() / 2 + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= uniq_size, "error"); TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= uniq_size, "error");
TV_ASSERT_RT_ERR(indice_num_per_loc.dim(0) == kv, "error"); TV_ASSERT_RT_ERR(indice_num_per_loc.dim(0) == kv, "error");
int64_t expected_out_size = indices.dim(0) * kv;
tv::cuda::Launch launcher_num_act_in(indices.dim(0), reinterpret_cast<cudaStream_t>(stream_int)); tv::cuda::Launch launcher_num_act_in(indices.dim(0), reinterpret_cast<cudaStream_t>(stream_int));
// tv::cuda::Launch launcher_num_act_in_2(indices.dim(0)); // tv::cuda::Launch launcher_num_act_in_2(indices.dim(0));
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
...@@ -828,6 +868,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -828,6 +868,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// use_bound_algo = true;
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
...@@ -837,16 +878,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -837,16 +878,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto ctx = tv::Context(); auto ctx = tv::Context();
ctx.set_cuda_stream(custream); ctx.set_cuda_stream(custream);
// indice_pairs: [2, kv, indices.dim(0)] // indice_pairs: [2, kv, num_act_in_bounded]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1] // indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [MaxSize, {self.ndim + 1}]
// auto timer = tv::CudaContextTimer<>(); // auto timer = tv::CudaContextTimer<>();
int64_t uniq_size = indice_pairs.size() / 2 + 1; int64_t uniq_size = indice_pairs.size() / 2 + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= num_out_act, "error"); TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= num_out_act, "error");
// int num_out_act_bounded = num_out_act;
// if (num_out_act_bound > 0){{
// num_out_act_bounded = std::min(num_out_act_bounded, num_out_act);
// }}
TV_ASSERT_RT_ERR(out_inds.dim(0) >= num_out_act && out_inds.dim(1) == {self.ndim + 1}, "error"); TV_ASSERT_RT_ERR(out_inds.dim(0) >= num_out_act && out_inds.dim(1) == {self.ndim + 1}, "error");
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream); tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
...@@ -915,15 +952,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -915,15 +952,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
// indice_pairs_bwd: [kv, indices.dim(0)] int num_act_in = indices.dim(0);
// indice_pairs_uniq: [indice_pairs_bwd.size() + 1] // indice_pairs_bwd: [kv, num_act_in] or empty
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}}); // indice_pairs_uniq: [kv * num_act_in + 1]
if (!indice_pairs_bwd.empty()){{
tv::check_shape(indice_pairs_bwd, {{kv, num_act_in}});
}}
tv::check_shape(indice_num_per_loc, {{kv}}); tv::check_shape(indice_num_per_loc, {{kv}});
int64_t uniq_size = indice_pairs_bwd.size() + 1; int64_t uniq_size = kv * num_act_in + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= uniq_size, "error"); TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) == uniq_size, "error");
int64_t expected_out_size = indices.dim(0) * kv;
tv::cuda::Launch launcher_num_act_in(indices.dim(0), reinterpret_cast<cudaStream_t>(stream_int)); tv::cuda::Launch launcher_num_act_in(indices.dim(0), reinterpret_cast<cudaStream_t>(stream_int));
// tv::cuda::Launch launcher_num_act_in_2(indices.dim(0)); // tv::cuda::Launch launcher_num_act_in_2(indices.dim(0));
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
...@@ -945,6 +984,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -945,6 +984,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_stage2_mask(self): def generate_conv_inds_stage2_mask(self):
"""here indice_pairs_uniq may be bounded, some
points may be dropped.
"""
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg( code.arg(
...@@ -965,21 +1007,26 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -965,21 +1007,26 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
// indice_pairs_bwd: [kv, indices.dim(0)] // indice_pairs_bwd: [kv, num_act_in] or empty
// indice_pairs_fwd: [kv, out_inds.dim(0)] // indice_pairs_fwd: [kv, num_act_out]
auto ctx = tv::Context(); auto ctx = tv::Context();
ctx.set_cuda_stream(custream); ctx.set_cuda_stream(custream);
int num_act_in = indices.dim(0);
int num_act_out = num_out_act;
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error"); TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error"); TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [num_out_act, {self.ndim + 1}]
// auto timer = tv::CudaContextTimer<>(); // auto timer = tv::CudaContextTimer<>();
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}}); if (!indice_pairs_bwd.empty()){{
tv::check_shape(indice_pairs_fwd, {{kv, num_out_act}}); tv::check_shape(indice_pairs_bwd, {{kv, num_act_in}});
}}
tv::check_shape(indice_pairs_fwd, {{kv, num_act_out}});
tv::check_shape(out_inds, {{num_out_act, {self.ndim + 1}}}); tv::check_shape(out_inds, {{num_out_act, {self.ndim + 1}}});
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream); tv::cuda::Launch launcher_num_act_in(num_act_in, custream);
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
tv::cuda::Launch launcher_num_act_in_no_y(indices.dim(0), custream); tv::cuda::Launch launcher_num_act_in_no_y(num_act_in, custream);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); ConvLocIter loc_iter(problem);
...@@ -1001,17 +1048,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1001,17 +1048,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(), out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act); loc_iter.layout_npq, num_out_act);
if (!mask_bwd.empty()){{ if (!mask_bwd.empty()){{
// auto timer = tv::CudaContextTimer<>();
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash, launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(), indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(), indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(), mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1)); num_act_in, indice_pairs_fwd.dim(1));
// tv::ssprint("calc_conv_indices_stage2_mask", timer.report() / 1000.0); launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output,
launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output, indice_pairs_bwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
mask_bwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
indice_pairs_bwd.dim(1), kv); num_act_in, kv);
// tv::ssprint("calc_conv_indices_stage2_mask_output", timer.report() / 1000.0);
if (mask_fwd.dim(0) == 2){{ if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx); mask_fwd[1].copy_(mask_fwd[0], ctx);
}} }}
...@@ -1023,7 +1068,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1023,7 +1068,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(), indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(), indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(), mask_fwd.data_ptr<uint32_t>(),
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1)); num_act_in, indice_pairs_fwd.dim(1));
if (mask_fwd.dim(0) == 2){{ if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx); mask_fwd[1].copy_(mask_fwd[0], ctx);
}} }}
...@@ -1043,10 +1088,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1043,10 +1088,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("ksize, dilation", f"tv::array<int, {self.ndim}>") code.arg("ksize, dilation", f"tv::array<int, {self.ndim}>")
code.arg("indice_pair_mask", "tv::Tensor", "tv::Tensor()", code.arg("indice_pair_mask", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("backward", "bool", "false") code.arg("is_train", "bool", "true")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
int num_act_in_real = indices.dim(0);
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
auto ctx = tv::Context(); auto ctx = tv::Context();
ctx.set_cuda_stream(custream); ctx.set_cuda_stream(custream);
...@@ -1063,17 +1109,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1063,17 +1109,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}} }}
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [2, kv, indices.dim(0)] // indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in]
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [MaxSize, {self.ndim + 1}]
// auto timer = tv::CudaContextTimer<>();
TV_ASSERT_RT_ERR(indice_num_per_loc.dim(0) == kv, "error"); TV_ASSERT_RT_ERR(indice_num_per_loc.dim(0) == kv, "error");
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream); tv::cuda::Launch launcher_num_act_in(num_act_in_real, custream);
launcher_num_act_in.blocks.y = (kv / 2) + 1; launcher_num_act_in.blocks.y = (kv / 2) + 1;
// launcher_num_act_in.blocks.y = kv; // launcher_num_act_in.blocks.y = kv;
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); ConvLocIter loc_iter(problem);
tv::cuda::Launch lanucher_build_hash(indices.dim(0), custream); tv::cuda::Launch lanucher_build_hash(num_act_in_real, custream);
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{ tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using V = {self.dtype_indices}; using V = {self.dtype_indices};
using K = TV_DECLTYPE(I); using K = TV_DECLTYPE(I);
...@@ -1083,43 +1129,45 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1083,43 +1129,45 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::default_empty_key_v<K>, false>; tv::hash::default_empty_key_v<K>, false>;
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= indices.dim(0), "hash size not enough"); TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_act_in_real, "hash size not enough");
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0)); table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
tv::hash::clear_map_split(hash, custream); tv::hash::clear_map_split(hash, custream);
lanucher_build_hash(build_subm_conv_hash_table<table_t>, hash, indices.data_ptr<const int>(), lanucher_build_hash(build_subm_conv_hash_table<table_t>, hash, indices.data_ptr<const int>(),
loc_iter.layout_npq, indices.dim(0)); loc_iter.layout_npq, num_act_in_real);
// tv::ssprint("build_hash time", timer.report() / 1000.0);
if (!indice_pair_mask.empty()){{ if (!indice_pair_mask.empty()){{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error"); TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error");
// indice_pair_mask: [mask_split_count, num_act_in]
if (indice_pair_mask.dim(0) == 2){{ if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0]; auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real);
tv::cuda::Launch lanucher_fill(mask_0.size(), custream); auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), mask_0.size()); tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
indice_pair_mask[1].zero_(ctx); lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
mask_1.zero_(ctx);
auto kernel = &calc_subm_conv_indices_split_mask<table_t>; auto kernel = &calc_subm_conv_indices_split_mask<table_t>;
launcher_num_act_in(kernel, loc_iter, hash, launcher_num_act_in(kernel, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(), indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask[0].data_ptr<uint32_t>(), indice_pair_mask[1].data_ptr<uint32_t>(), mask_0.data_ptr<uint32_t>(), mask_1.data_ptr<uint32_t>(),
indices.dim(0), indice_pairs.dim(2), kv); indices.dim(0), indice_pairs.dim(2), kv, is_train);
}}else{{ }}else{{
tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream); // indice_pair_mask: [1, num_act_in]
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indice_pair_mask.size()); tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error"); TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash, launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(), indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv); indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv, is_train);
}} }}
}}else{{ }}else{{
launcher_num_act_in(calc_subm_conv_indices<table_t>, loc_iter, hash, indices.data_ptr<int>(), TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == 2, "error");
launcher_num_act_in(calc_subm_conv_indices<table_t>, loc_iter, hash, indices.data_ptr<const int>(),
indice_pairs.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs.dim(2), kv); indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs.dim(2), kv);
}} }}
}}); }});
// tv::ssprint("clear hash time", hashdata.dim(0), timer.report() / 1000.0);
// tv::ssprint("gem subm conv inds time", timer.report() / 1000.0);
return indices.dim(0); return indices.dim(0);
""") """)
...@@ -1166,7 +1214,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1166,7 +1214,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
int indices_pair_size_mul_RS = indices_pair_size * kv; int indices_pair_size_mul_RS = indices_pair_size * kv;
auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>(); auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>();
std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash; std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash;
auto indices_ptr = indices.data_ptr<{self.dtype_indices}>(); auto indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
int indice_in_num = indices.dim(0); int indice_in_num = indices.dim(0);
for (int i = 0; i < indice_in_num; ++i){{ for (int i = 0; i < indice_in_num; ++i){{
{self.dtype_indices} index = loc_iter.layout_npq(indices_ptr); {self.dtype_indices} index = loc_iter.layout_npq(indices_ptr);
...@@ -1182,7 +1230,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1182,7 +1230,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + i] = i; indice_pairs_ptr[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + i] = i;
}} }}
}}else{{ }}else{{
indices_ptr = indices.data_ptr<{self.dtype_indices}>(); indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset; auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset;
for (int i = 0; i < indice_in_num; ++i){{ for (int i = 0; i < indice_in_num; ++i){{
tv::array<int, {self.ndim + 1}> npq_offset; tv::array<int, {self.ndim + 1}> npq_offset;
...@@ -1224,7 +1272,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1224,7 +1272,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
int indices_pair_size_mul_RS = indices_pair_size * kv; int indices_pair_size_mul_RS = indices_pair_size * kv;
auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>(); auto indice_pairs_ptr = indice_pairs.data_ptr<{self.dtype_indices}>();
std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash; std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash;
auto indices_ptr = indices.data_ptr<{self.dtype_indices}>(); auto indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
auto out_inds_ptr = out_inds.data_ptr<{self.dtype_indices}>(); auto out_inds_ptr = out_inds.data_ptr<{self.dtype_indices}>();
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<{self.dtype_indices}>::max(), TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}"); "kernel volume must smaller than max value of {self.dtype_indices}");
...@@ -1234,7 +1282,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1234,7 +1282,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
{self.dtype_indices} hashval; {self.dtype_indices} hashval;
for (int filter_offset = 0; filter_offset < kv; ++filter_offset){{ for (int filter_offset = 0; filter_offset < kv; ++filter_offset){{
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size; int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
indices_ptr = indices.data_ptr<{self.dtype_indices}>(); indices_ptr = indices.data_ptr<const {self.dtype_indices}>();
auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset; auto indice_num_per_loc_ptr = indice_num_per_loc.data_ptr<{self.dtype_indices}>() + filter_offset;
for (int i = 0; i < indice_in_num; ++i){{ for (int i = 0; i < indice_in_num; ++i){{
tv::array<int, {self.ndim + 1}> npq_offset; tv::array<int, {self.ndim + 1}> npq_offset;
......
...@@ -180,6 +180,85 @@ class IndiceMaxPool(pccm.Class): ...@@ -180,6 +180,85 @@ class IndiceMaxPool(pccm.Class):
""") """)
return code return code
@pccm.cuda.cuda_global_function
def forward_avgpool_implicit_gemm_kernel(self):
code = pccm.FunctionCode()
code.targ("T")
code.arg("out_features", f"T*")
code.arg("in_features", f"const T*")
code.arg("indices", "const int*")
code.arg("count_out", "int*")
code.arg("num_features", "int")
code.arg("RS", "int")
code.arg("num_indices", "int")
code.raw(f"""
for (int i : tv::KernelLoopY<int>(num_indices)) {{
auto out_ptr = out_features + i * num_features;
auto indices_ptr = indices + i;
int in_idx = 0;
int count = 0;
for (int k = 0; k < RS; ++k){{
in_idx = indices_ptr[0];
count += int(in_idx != -1);
indices_ptr += num_indices;
}}
if (count_out != nullptr){{
count_out[i] = count;
}}
for (int j : tv::KernelLoopX<int>(num_features)) {{
indices_ptr = indices + i;
int in_idx;
T in, in_temp;
in = T(0);
for (int k = 0; k < RS; ++k){{
in_idx = indices_ptr[0];
bool valid = in_idx != -1;
in_temp = valid ? in_features[in_idx * num_features + j] : T(0);
in += in_temp;
indices_ptr += num_indices;
}}
out_ptr[j] = count > 0 ? in / T(count) : T(0);
}}
}}
""")
return code
@pccm.cuda.cuda_global_function
def backward_avgpool_implicit_gemm_kernel(self):
code = pccm.FunctionCode()
code.targ("T")
code.arg("dout_features", f"const T*")
code.arg("din_features", f"T*")
code.arg("indices_bwd", "const int*")
code.arg("count_out", "const int*")
code.arg("num_features", "int")
code.arg("RS", "int")
code.arg("num_indices", "int")
code.raw(f"""
for (int i : tv::KernelLoopY<int>(num_indices)) {{
auto din_ptr = din_features + i * num_features;
for (int j : tv::KernelLoopX<int>(num_features)) {{
auto indices_ptr = indices_bwd + i;
int out_idx = 0;
T sum_val = T(0);
for (int k = 0; k < RS; ++k){{
out_idx = indices_ptr[0];
bool valid = out_idx != -1;
T dout = valid ? dout_features[out_idx * num_features + j] : T(0);
int count = valid ? count_out[out_idx] : T(0);
sum_val += dout * T(count);
indices_ptr += num_indices;
}}
din_ptr[j] = sum_val;
}}
}}
""")
return code
@pccm.cuda.static_function @pccm.cuda.static_function
def forward(self): def forward(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -348,6 +427,92 @@ class IndiceMaxPool(pccm.Class): ...@@ -348,6 +427,92 @@ class IndiceMaxPool(pccm.Class):
""") """)
return code return code
@pccm.cuda.static_function
def forward_avgpool_implicit_gemm(self):
code = pccm.FunctionCode()
code.arg("out", "tv::Tensor")
code.arg("in", "tv::Tensor")
code.arg("inds", "tv::Tensor")
code.arg("count_out", "tv::Tensor")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
auto nhot = out.dim(0);
tv::check_shape(inds, {{-1, nhot}});
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
constexpr int MaxThreads = 512;
tv::cuda::Launch launcher(1);
bool found = tv::dispatch_int_noexcept<512, 256, 128, 64, 32, 16>(out.dim(1), [](int my, int expect){{return my >= expect;}}, [&](auto V){{
// if out.dim(1) > value in list above, run this function.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
if (!found){{
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
launcher(forward_avgpool_implicit_gemm_kernel<T>, out.data_ptr<T>(), in.data_ptr<const T>(),
inds.data_ptr<const int>(), count_out.data_ptr<int>(), out.dim(1), inds.dim(0), inds.dim(1));
}});
""")
return code
@pccm.cuda.static_function
def backward_avgpool_implicit_gemm(self):
code = pccm.FunctionCode()
code.arg("dout", "tv::Tensor")
code.arg("din", "tv::Tensor")
code.arg("inds", "tv::Tensor")
code.arg("count_out", "tv::Tensor")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
auto nhot = din.dim(0);
TV_ASSERT_RT_ERR(!count_out.empty(), "count out must not empty")
tv::check_shape(inds, {{-1, nhot}});
tv::check_shape(din, {{-1, dout.dim(1)}});
int num_act_out = dout.dim(1);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(dout.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
constexpr int MaxThreads = 512;
tv::cuda::Launch launcher(1);
bool found = tv::dispatch_int_noexcept<512, 256, 128, 64, 32, 16>(dout.dim(1), [](int my, int expect){{return my >= expect;}}, [&](auto V){{
// if out.dim(1) > value in list above, run this function.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(dout.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
if (!found){{
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(dout.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
launcher(backward_avgpool_implicit_gemm_kernel<T>,
dout.data_ptr<const T>(), din.data_ptr<T>(),
inds.data_ptr<const int>(), count_out.data_ptr<const int>(),
dout.dim(1), inds.dim(0), inds.dim(1));
}});
""")
return code
class IndiceMaxPoolCPU(pccm.Class): class IndiceMaxPoolCPU(pccm.Class):
def __init__(self): def __init__(self):
......
...@@ -297,7 +297,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -297,7 +297,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
self.add_dependency(TensorView) self.add_dependency(TensorView)
self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx) self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx)
self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon") self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon")
layout = TensorGeneric(ndim, True) layout = TensorGeneric(ndim, False)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
...@@ -489,7 +489,7 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -489,7 +489,7 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True):
super().__init__() super().__init__()
self.add_dependency(TensorView) self.add_dependency(TensorView)
layout = TensorGeneric(ndim, True) layout = TensorGeneric(ndim, False)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
......
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -10,33 +10,41 @@ from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS, ...@@ -10,33 +10,41 @@ from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS,
SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS) SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS)
from spconv.csrc.hash.core import HashTable from spconv.csrc.hash.core import HashTable
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.sparse.alloc import ExternalAllocator from spconv.csrc.sparse.alloc import ExternalAllocator, StaticAllocator
from spconv.csrc.sparse.convops import (ConvGemmOps, ConvTunerSimple, from spconv.csrc.sparse.convops import (ConvGemmOps, ConvTunerSimple,
ExternalSpconvMatmul, GemmTunerSimple, ExternalSpconvMatmul, GemmTunerSimple,
SimpleExternalSpconvMatmul) SimpleExternalSpconvMatmul)
from spconv.csrc.utils import BoxOps from spconv.csrc.utils import BoxOps
from cumm.gemm.algospec.core import (GemmAlgo, ShuffleStrideType)
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
def main(include: str, def main(include: str,
src: str, src: str,
libname: str = "spconv", libname: str = "spconv",
prefix: str = "spconvlib"): prefix: str = "spconvlib",
inference_only: bool = False):
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle)) all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
if inference_only:
all_shuffle = list(filter(lambda x: x.shuffle_stride != ShuffleStrideType.ShuffleAB, all_shuffle))
cu = GemmMainUnitTest(all_shuffle) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS) IMPLGEMM_TURING_PARAMS)
# all_imp = IMPLGEMM_SIMT_PARAMS # all_imp = IMPLGEMM_SIMT_PARAMS
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
if inference_only:
all_imp = list(filter(lambda x: x.op_type == ConvOpType.kForward, all_imp))
convcu = ConvMainUnitTest(all_imp) convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
gemmtuner = GemmTunerSimple(cu) gemmtuner = GemmTunerSimple(cu)
gemmtuner.namespace = "csrc.sparse.convops.gemmops" gemmtuner.namespace = "spconv.csrc.sparse.convops.gemmops"
convtuner = ConvTunerSimple(convcu) convtuner = ConvTunerSimple(convcu)
convtuner.namespace = "csrc.sparse.convops.convops" convtuner.namespace = "spconv.csrc.sparse.convops.convops"
convops = ConvGemmOps(gemmtuner, convtuner) convops = ConvGemmOps(gemmtuner, convtuner)
convops.namespace = "csrc.sparse.convops.spops" convops.namespace = "spconv.csrc.sparse.convops.spops"
cus = [ cus = [
cu, cu,
...@@ -51,6 +59,7 @@ def main(include: str, ...@@ -51,6 +59,7 @@ def main(include: str,
ExternalAllocator(), ExternalAllocator(),
ExternalSpconvMatmul(), ExternalSpconvMatmul(),
SimpleExternalSpconvMatmul(), SimpleExternalSpconvMatmul(),
StaticAllocator(),
] ]
gen_cmake(libname, cus, include, src, namespace_prefix=prefix) gen_cmake(libname, cus, include, src, namespace_prefix=prefix)
......
...@@ -17,7 +17,9 @@ from spconv.pytorch.modules import (SparseModule, SparseSequential, ...@@ -17,7 +17,9 @@ from spconv.pytorch.modules import (SparseModule, SparseSequential,
assign_name_for_sparse_modules) assign_name_for_sparse_modules)
from spconv.pytorch.ops import ConvAlgo from spconv.pytorch.ops import ConvAlgo
from spconv.pytorch.pool import (SparseMaxPool1d, SparseMaxPool2d, from spconv.pytorch.pool import (SparseMaxPool1d, SparseMaxPool2d,
SparseMaxPool3d, SparseMaxPool4d) SparseMaxPool3d, SparseMaxPool4d,
SparseAvgPool1d, SparseAvgPool2d,
SparseAvgPool3d)
from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable
......
...@@ -38,6 +38,9 @@ from torch.nn.init import calculate_gain ...@@ -38,6 +38,9 @@ from torch.nn.init import calculate_gain
FILTER_HWIO = False FILTER_HWIO = False
_MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training"
class SparseConvolution(SparseModule): class SparseConvolution(SparseModule):
__constants__ = [ __constants__ = [
'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse', 'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
...@@ -61,6 +64,7 @@ class SparseConvolution(SparseModule): ...@@ -61,6 +64,7 @@ class SparseConvolution(SparseModule):
indice_key: Optional[str] = None, indice_key: Optional[str] = None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConvolution, self).__init__(name=name) super(SparseConvolution, self).__init__(name=name)
assert groups == 1, "don't support groups for now" assert groups == 1, "don't support groups for now"
...@@ -89,6 +93,12 @@ class SparseConvolution(SparseModule): ...@@ -89,6 +93,12 @@ class SparseConvolution(SparseModule):
self.groups = groups self.groups = groups
self.subm = subm self.subm = subm
self.indice_key = indice_key self.indice_key = indice_key
if record_voxel_count and not self.subm and not self.inverse:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self.register_buffer(_MAX_NUM_VOXELS_DURING_TRAINING,
torch.zeros(1, dtype=torch.int32))
self.record_voxel_count = record_voxel_count
if algo is None: if algo is None:
if kv <= 32 and not CPU_ONLY_BUILD: if kv <= 32 and not CPU_ONLY_BUILD:
if kv < 8: if kv < 8:
...@@ -122,37 +132,46 @@ class SparseConvolution(SparseModule): ...@@ -122,37 +132,46 @@ class SparseConvolution(SparseModule):
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()
if hasattr(self, "_register_load_state_dict_pre_hook"):
self._register_load_state_dict_pre_hook(self._load_weight_different_layout) self._register_load_state_dict_pre_hook(
self._load_weight_different_layout)
def _load_weight_different_layout(
self, state_dict, prefix, local_metadata, strict, def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, error_msgs): strict, missing_keys, unexpected_keys,
error_msgs):
if self.record_voxel_count and not self.subm and not self.inverse and _MAX_NUM_VOXELS_DURING_TRAINING not in state_dict:
state_dict[prefix + _MAX_NUM_VOXELS_DURING_TRAINING] = torch.zeros(
1, dtype=torch.int32)
if not SAVED_WEIGHT_LAYOUT: if not SAVED_WEIGHT_LAYOUT:
return return
key = prefix + "weight" key = prefix + "weight"
assert key in state_dict assert key in state_dict
ndim = self.ndim ndim = self.ndim
if SAVED_WEIGHT_LAYOUT == "RSKC": if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(ndim, *range(ndim), ndim + 1).contiguous() state_dict[key] = state_dict[key].permute(ndim, *range(ndim),
ndim + 1).contiguous()
elif SAVED_WEIGHT_LAYOUT == "RSCK": elif SAVED_WEIGHT_LAYOUT == "RSCK":
state_dict[key] = state_dict[key].permute(ndim + 1, *range(ndim), ndim).contiguous() state_dict[key] = state_dict[key].permute(ndim + 1, *range(ndim),
ndim).contiguous()
if ALL_WEIGHT_IS_KRSC or self.algo != ConvAlgo.Native: if ALL_WEIGHT_IS_KRSC or self.algo != ConvAlgo.Native:
# in spconv 2.2, we only support KRSC layout. # in spconv 2.2, we only support KRSC layout.
if SAVED_WEIGHT_LAYOUT == "RSKC": if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(ndim, *range(ndim), ndim + 1).contiguous() state_dict[key] = state_dict[key].permute(
ndim, *range(ndim), ndim + 1).contiguous()
elif SAVED_WEIGHT_LAYOUT == "RSCK": elif SAVED_WEIGHT_LAYOUT == "RSCK":
state_dict[key] = state_dict[key].permute(ndim + 1, *range(ndim), ndim).contiguous() state_dict[key] = state_dict[key].permute(
ndim + 1, *range(ndim), ndim).contiguous()
else: else:
if self.algo == ConvAlgo.Native: if self.algo == ConvAlgo.Native:
# to RSCK # to RSCK
if SAVED_WEIGHT_LAYOUT == "RSKC": if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(*range(ndim), ndim + 1, ndim).contiguous() state_dict[key] = state_dict[key].permute(
*range(ndim), ndim + 1, ndim).contiguous()
elif SAVED_WEIGHT_LAYOUT == "KRSC": elif SAVED_WEIGHT_LAYOUT == "KRSC":
state_dict[key] = state_dict[key].permute(*range(1, ndim + 1), 0, ndim + 1).contiguous() state_dict[key] = state_dict[key].permute(
*range(1, ndim + 1), 0, ndim + 1).contiguous()
def extra_repr(self): def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
...@@ -218,6 +237,9 @@ class SparseConvolution(SparseModule): ...@@ -218,6 +237,9 @@ class SparseConvolution(SparseModule):
bound = 1 / math.sqrt(fan_in) bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound) init.uniform_(self.bias, -bound, bound)
def is_inverseable(self):
return self.indice_key is not None and not self.subm
def forward(self, input: SparseConvTensor): def forward(self, input: SparseConvTensor):
assert isinstance(input, SparseConvTensor) assert isinstance(input, SparseConvTensor)
assert input.features.shape[ assert input.features.shape[
...@@ -410,7 +432,6 @@ class SparseConvolution(SparseModule): ...@@ -410,7 +432,6 @@ class SparseConvolution(SparseModule):
self._check_subm_reuse_valid(input, spatial_shape, self._check_subm_reuse_valid(input, spatial_shape,
datas) datas)
else: else:
with input._timer.namespace("gen_pairs"): with input._timer.namespace("gen_pairs"):
# we need to gen bwd indices for regular conv # we need to gen bwd indices for regular conv
# because it may be inversed. # because it may be inversed.
...@@ -491,6 +512,14 @@ class SparseConvolution(SparseModule): ...@@ -491,6 +512,14 @@ class SparseConvolution(SparseModule):
features.shape[0]) features.shape[0])
out_tensor.benchmark_record[self.name]["num_out_points"].append( out_tensor.benchmark_record[self.name]["num_out_points"].append(
out_features.shape[0]) out_features.shape[0])
if not self.subm and not self.inverse and self.record_voxel_count:
if hasattr(self,
_MAX_NUM_VOXELS_DURING_TRAINING):
ops.maximum_value_int_(
getattr(
self,
_MAX_NUM_VOXELS_DURING_TRAINING),
outids.shape[0])
out_tensor = out_tensor.replace_feature(out_features) out_tensor = out_tensor.replace_feature(out_features)
out_tensor.indices = outids out_tensor.indices = outids
out_tensor.indice_dict = indice_dict out_tensor.indice_dict = indice_dict
...@@ -534,20 +563,23 @@ class SparseConv1d(SparseConvolution): ...@@ -534,20 +563,23 @@ class SparseConv1d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConv1d, self).__init__(1, super(SparseConv1d,
in_channels, self).__init__(1,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
indice_key=indice_key, bias,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
name=name)
class SparseConv2d(SparseConvolution): class SparseConv2d(SparseConvolution):
...@@ -563,20 +595,23 @@ class SparseConv2d(SparseConvolution): ...@@ -563,20 +595,23 @@ class SparseConv2d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConv2d, self).__init__(2, super(SparseConv2d,
in_channels, self).__init__(2,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
indice_key=indice_key, bias,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
name=name)
class SparseConv3d(SparseConvolution): class SparseConv3d(SparseConvolution):
...@@ -592,20 +627,23 @@ class SparseConv3d(SparseConvolution): ...@@ -592,20 +627,23 @@ class SparseConv3d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConv3d, self).__init__(3, super(SparseConv3d,
in_channels, self).__init__(3,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
indice_key=indice_key, bias,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
name=name)
class SparseConv4d(SparseConvolution): class SparseConv4d(SparseConvolution):
...@@ -621,20 +659,23 @@ class SparseConv4d(SparseConvolution): ...@@ -621,20 +659,23 @@ class SparseConv4d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConv4d, self).__init__(4, super(SparseConv4d,
in_channels, self).__init__(4,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
indice_key=indice_key, bias,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
name=name)
class SparseConvTranspose1d(SparseConvolution): class SparseConvTranspose1d(SparseConvolution):
...@@ -650,21 +691,24 @@ class SparseConvTranspose1d(SparseConvolution): ...@@ -650,21 +691,24 @@ class SparseConvTranspose1d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConvTranspose1d, self).__init__(1, super(SparseConvTranspose1d,
in_channels, self).__init__(1,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
transposed=True, bias,
indice_key=indice_key, transposed=True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
name=name)
class SparseConvTranspose2d(SparseConvolution): class SparseConvTranspose2d(SparseConvolution):
...@@ -680,21 +724,24 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -680,21 +724,24 @@ class SparseConvTranspose2d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConvTranspose2d, self).__init__(2, super(SparseConvTranspose2d,
in_channels, self).__init__(2,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
transposed=True, bias,
indice_key=indice_key, transposed=True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
name=name)
class SparseConvTranspose3d(SparseConvolution): class SparseConvTranspose3d(SparseConvolution):
...@@ -710,21 +757,24 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -710,21 +757,24 @@ class SparseConvTranspose3d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConvTranspose3d, self).__init__(3, super(SparseConvTranspose3d,
in_channels, self).__init__(3,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
transposed=True, bias,
indice_key=indice_key, transposed=True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
name=name)
class SparseConvTranspose4d(SparseConvolution): class SparseConvTranspose4d(SparseConvolution):
...@@ -740,21 +790,24 @@ class SparseConvTranspose4d(SparseConvolution): ...@@ -740,21 +790,24 @@ class SparseConvTranspose4d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
name=None): name=None):
super(SparseConvTranspose4d, self).__init__(4, super(SparseConvTranspose4d,
in_channels, self).__init__(4,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
transposed=True, bias,
indice_key=indice_key, transposed=True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
name=name)
class SparseInverseConv1d(SparseConvolution): class SparseInverseConv1d(SparseConvolution):
......
...@@ -12,13 +12,14 @@ ...@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Tuple, Union, Dict from typing import Any, List, Optional, Tuple, Union, Dict
import numpy as np import numpy as np
import torch import torch
from spconv.core import ConvAlgo from spconv.core import ConvAlgo
from spconv.pytorch.constants import PYTORCH_VERSION from spconv.pytorch.constants import PYTORCH_VERSION
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
from spconv.constants import SPCONV_FX_TRACE_MODE
if PYTORCH_VERSION >= [1, 8, 0]: if PYTORCH_VERSION >= [1, 8, 0]:
try: try:
...@@ -59,7 +60,8 @@ class ThrustSortAllocator: ...@@ -59,7 +60,8 @@ class ThrustSortAllocator:
class IndiceData(object): class IndiceData(object):
def __init__(self, out_indices, indices, indice_pairs, indice_pair_num, def __init__(self, out_indices, indices, indice_pairs, indice_pair_num,
spatial_shape, out_spatial_shape, is_subm: bool, algo: ConvAlgo, spatial_shape, out_spatial_shape, is_subm: bool, algo: ConvAlgo,
ksize: List[int], stride: List[int], dilation: List[int], padding: List[int]): ksize: List[int], stride: List[int], dilation: List[int], padding: List[int],
voxel_num: Optional[Any] = None):
self.out_indices = out_indices self.out_indices = out_indices
self.indices = indices self.indices = indices
self.indice_pairs = indice_pairs self.indice_pairs = indice_pairs
...@@ -72,6 +74,8 @@ class IndiceData(object): ...@@ -72,6 +74,8 @@ class IndiceData(object):
self.stride = stride self.stride = stride
self.dilation = dilation self.dilation = dilation
self.padding = padding self.padding = padding
# voxel_num is only used in tensorrt conversion.
self.voxel_num = voxel_num
class ImplicitGemmIndiceData(object): class ImplicitGemmIndiceData(object):
...@@ -83,7 +87,9 @@ class ImplicitGemmIndiceData(object): ...@@ -83,7 +87,9 @@ class ImplicitGemmIndiceData(object):
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
masks: List[np.ndarray], spatial_shape, masks: List[np.ndarray], spatial_shape,
out_spatial_shape, is_subm: bool, algo: ConvAlgo, out_spatial_shape, is_subm: bool, algo: ConvAlgo,
ksize: List[int], stride: List[int], dilation: List[int], padding: List[int]): ksize: List[int], stride: List[int], dilation: List[int], padding: List[int],
in_voxel_num: Optional[Any] = None,
out_voxel_num: Optional[Any] = None):
self.out_indices = out_indices self.out_indices = out_indices
self.indices = indices self.indices = indices
self.pair_fwd = pair_fwd self.pair_fwd = pair_fwd
...@@ -101,6 +107,9 @@ class ImplicitGemmIndiceData(object): ...@@ -101,6 +107,9 @@ class ImplicitGemmIndiceData(object):
self.stride = stride self.stride = stride
self.dilation = dilation self.dilation = dilation
self.padding = padding self.padding = padding
# in/out voxel_num is only used in tensorrt conversion.
self.in_voxel_num = in_voxel_num
self.out_voxel_num = out_voxel_num
def scatter_nd(indices, updates, shape): def scatter_nd(indices, updates, shape):
...@@ -147,11 +156,12 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -147,11 +156,12 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
force_algo: force conv/pool layers use this algo, should only used for debug. force_algo: force conv/pool layers use this algo, should only used for debug.
""" """
ndim = indices.shape[1] - 1 ndim = indices.shape[1] - 1
assert features.ndim == 2 if not SPCONV_FX_TRACE_MODE:
assert indices.ndim == 2 assert features.ndim == 2
assert len(spatial_shape) == ndim, "spatial shape must equal to ndim" assert indices.ndim == 2
assert indices.dtype == torch.int32, "only support int32" assert len(spatial_shape) == ndim, "spatial shape must equal to ndim"
assert batch_size > 0 assert indices.dtype == torch.int32, "only support int32"
assert batch_size > 0
self._features = features self._features = features
self.indices = indices self.indices = indices
self.spatial_shape = [int(v) for v in spatial_shape] self.spatial_shape = [int(v) for v in spatial_shape]
......
...@@ -103,7 +103,7 @@ class TorchAllocator(ExternalAllocator): ...@@ -103,7 +103,7 @@ class TorchAllocator(ExternalAllocator):
self.allocated: Dict[Union[str, int], torch.Tensor] = {} self.allocated: Dict[Union[str, int], torch.Tensor] = {}
def zeros(self, name: str, shape: List[int], dtype: int, def zeros(self, name: str, shape: List[int], dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
# TODO free memory by name if its already free by pointer. # TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit. # provide a name if you want to access it after c++ function exit.
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
...@@ -126,7 +126,7 @@ class TorchAllocator(ExternalAllocator): ...@@ -126,7 +126,7 @@ class TorchAllocator(ExternalAllocator):
return ten_tv return ten_tv
def empty(self, name: str, shape: List[int], dtype: int, def empty(self, name: str, shape: List[int], dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS: if dtype in _TORCH_UINT_WORKAROUNDS:
...@@ -147,7 +147,7 @@ class TorchAllocator(ExternalAllocator): ...@@ -147,7 +147,7 @@ class TorchAllocator(ExternalAllocator):
return ten_tv return ten_tv
def full_int(self, name: str, shape: List[int], value: int, dtype: int, def full_int(self, name: str, shape: List[int], value: int, dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0: if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes") raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
...@@ -171,7 +171,7 @@ class TorchAllocator(ExternalAllocator): ...@@ -171,7 +171,7 @@ class TorchAllocator(ExternalAllocator):
return ten_tv return ten_tv
def full_float(self, name: str, shape: List[int], value: float, dtype: int, def full_float(self, name: str, shape: List[int], value: float, dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0: if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes") raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
......
...@@ -361,6 +361,25 @@ class SparseMaxPoolImplicitGemmFunction(Function): ...@@ -361,6 +361,25 @@ class SparseMaxPoolImplicitGemmFunction(Function):
features, out, grad_output, indice_pairs_bwd) features, out, grad_output, indice_pairs_bwd)
return input_bp, None, None, None return input_bp, None, None, None
class SparseAvgPoolImplicitGemmFunction(Function):
@staticmethod
@_TORCH_CUSTOM_FWD
def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
indice_pairs_bwd: torch.Tensor, num_activate_out: int, calc_count):
out, count = ops.indice_avgpool_implicit_gemm(features, indice_pairs_fwd,
num_activate_out, calc_count)
ctx.save_for_backward(indice_pairs_bwd, features, out, count)
return out
@staticmethod
@once_differentiable
@_TORCH_CUSTOM_BWD
def backward(ctx, grad_output):
indice_pairs_bwd, features, out, count = ctx.saved_tensors
input_bp = ops.indice_avgpool_implicit_gemm_backward(
grad_output, indice_pairs_bwd, count)
return input_bp, None, None, None, None
indice_conv = SparseConvFunction.apply indice_conv = SparseConvFunction.apply
implicit_gemm = SparseImplicitGemmFunction.apply implicit_gemm = SparseImplicitGemmFunction.apply
...@@ -368,6 +387,7 @@ indice_inverse_conv = SparseInverseConvFunction.apply ...@@ -368,6 +387,7 @@ indice_inverse_conv = SparseInverseConvFunction.apply
indice_subm_conv = SubMConvFunction.apply indice_subm_conv = SubMConvFunction.apply
indice_maxpool = SparseMaxPoolFunction.apply indice_maxpool = SparseMaxPoolFunction.apply
indice_maxpool_implicit_gemm = SparseMaxPoolImplicitGemmFunction.apply indice_maxpool_implicit_gemm = SparseMaxPoolImplicitGemmFunction.apply
indice_avgpool_implicit_gemm = SparseAvgPoolImplicitGemmFunction.apply
def _indice_to_scalar(indices: torch.Tensor, shape: List[int]): def _indice_to_scalar(indices: torch.Tensor, shape: List[int]):
......
...@@ -132,12 +132,11 @@ class SparseSequential(SparseModule): ...@@ -132,12 +132,11 @@ class SparseSequential(SparseModule):
if isinstance(input, list): if isinstance(input, list):
input = module(input) input = module(input)
else: else:
assert isinstance(input, spconv.SparseConvTensor) # assert isinstance(input, spconv.SparseConvTensor)
# self._sparity_dict[k] = input.sparity # self._sparity_dict[k] = input.sparity
input = module(input) input = module(input)
else: else:
if isinstance(input, spconv.SparseConvTensor): if isinstance(input, spconv.SparseConvTensor):
print(input.features.shape)
if input.indices.shape[0] != 0: if input.indices.shape[0] != 0:
input = input.replace_feature(module(input.features)) input = input.replace_feature(module(input.features))
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment