Commit 73a5ce7d authored by yan.yan's avatar yan.yan
Browse files

add direct table

parent 0c07559f
......@@ -95,13 +95,19 @@ class AllocKeys:
HashV = "HashV"
ThrustTemp = "ThrustTemp"
TightUniqueCount = "TightUniqueCount"
SPCONV_DEBUG_WEIGHT = False
SPCONV_CPP_INDICE_PAIRS = False
SPCONV_CPP_INDICE_PAIRS_IGEMM = False
SPCONV_CPP_GEMM = False
# Generating indice pairs in C++ is currently slightly slower than in Python; the cause is unknown.
SPCONV_CPP_INDICE_PAIRS_IGEMM = os.getenv("SPCONV_CPP_INDICE_PAIRS_IGEMM", "0") == "1"
SPCONV_CPP_GEMM = True
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
\ No newline at end of file
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ThrustCustomAllocatorV2:
    # Allocation callback: one int argument, int result.
    # ``Callable[int, int]`` is not a valid typing form (the argument types
    # must be given as a list), so it is spelled ``Callable[[int], int]``.
    # NOTE(review): presumably maps a byte size to a device address -- confirm
    # against the C++ declaration.
    alloc_func: Callable[[int], int]
class SpconvOps:
......@@ -92,6 +93,55 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_mask_stage1_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_num_per_loc: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> None:
    """
    Signature of the direct-table variant of stage-1 conv indice generation.

    Only the signature is declared here (the body is ``...``); nothing is
    returned. NOTE(review): the argument descriptions below are inferred
    from names only -- confirm against the implementation.

    Args:
        indices: presumably the input point coordinates -- TODO confirm.
        hashdata_k: presumably key storage of the hash table -- TODO confirm.
        hashdata_v: presumably value storage of the hash table -- TODO confirm.
        indice_pairs_bwd: presumably the backward indice-pair buffer -- TODO confirm.
        indice_pairs_uniq: presumably the buffer of output offsets to be made unique -- TODO confirm.
        indice_num_per_loc: presumably a per-kernel-location counter -- TODO confirm.
        batch_size: batch size.
        output_dims: output spatial shape, one entry per spatial dim.
        input_dims: input spatial shape, one entry per spatial dim.
        ksize: kernel size per spatial dim.
        stride: stride per spatial dim.
        padding: padding per spatial dim.
        dilation: dilation per spatial dim.
        transposed: defaults to False; presumably selects transposed conv -- TODO confirm.
        stream_int: defaults to 0; presumably a CUDA stream handle as int -- TODO confirm.
    """
    ...
@staticmethod
def unique_hash(hashdata_k: Tensor, hashdata_v: Tensor, uniq_cnt: Tensor, out_indices_offset: Tensor, num_out_bound: int, stream_int: int = 0) -> int:
    """
    Signature of the hash-table based "unique" step; returns an int.

    Only the signature is declared here (the body is ``...``).
    NOTE(review): the return value is presumably the number of unique
    entries found -- confirm against the implementation.

    Args:
        hashdata_k: presumably key storage of the hash table -- TODO confirm.
        hashdata_v: presumably value storage of the hash table -- TODO confirm.
        uniq_cnt: presumably a one-element counter tensor -- TODO confirm.
        out_indices_offset: presumably receives the unique keys -- TODO confirm.
        num_out_bound: presumably an upper bound on the number of outputs -- TODO confirm.
        stream_int: defaults to 0; presumably a CUDA stream handle as int -- TODO confirm.
    """
    ...
@staticmethod
def assign_output_direct_hash(out_indices_offset: Tensor, out_indices: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], stream_int: int = 0) -> None:
    """
    Signature of the step that fills ``out_indices`` for the direct-hash path.

    Only the signature is declared here (the body is ``...``); nothing is
    returned. NOTE(review): the argument descriptions below are inferred
    from names only -- confirm against the implementation.

    Args:
        out_indices_offset: presumably flattened output offsets -- TODO confirm.
        out_indices: presumably the output coordinate tensor to be written -- TODO confirm.
        batch_size: batch size.
        output_dims: output spatial shape, one entry per spatial dim.
        input_dims: input spatial shape, one entry per spatial dim.
        ksize: kernel size per spatial dim.
        stride: stride per spatial dim.
        padding: padding per spatial dim.
        dilation: dilation per spatial dim.
        stream_int: defaults to 0; presumably a CUDA stream handle as int -- TODO confirm.
    """
    ...
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
......@@ -118,6 +168,32 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
    """
    Signature of the direct-table variant of stage-2 mask indice generation.

    Only the signature is declared here (the body is ``...``). The parameter
    list is identical to ``generate_conv_inds_mask_stage2``.
    NOTE(review): the meaning of the returned int and of the tensor
    arguments is not visible here -- descriptions are assumptions to confirm.

    Args:
        indices: presumably the input point coordinates -- TODO confirm.
        hashdata_k: presumably key storage of the hash table -- TODO confirm.
        hashdata_v: presumably value storage of the hash table -- TODO confirm.
        indice_pairs_fwd: presumably the forward indice-pair buffer -- TODO confirm.
        indice_pairs_bwd: presumably the backward indice-pair buffer -- TODO confirm.
        indice_pairs_uniq: presumably the unique output offsets -- TODO confirm.
        indice_pairs_uniq_before_sort: presumably the offsets as produced by stage 1 -- TODO confirm.
        out_inds: presumably the output coordinates -- TODO confirm.
        mask_fwd: presumably the forward pair mask -- TODO confirm.
        mask_bwd: presumably the backward pair mask -- TODO confirm.
        num_out_act: presumably the number of active outputs -- TODO confirm.
        batch_size: batch size.
        output_dims: output spatial shape, one entry per spatial dim.
        input_dims: input spatial shape, one entry per spatial dim.
        ksize: kernel size per spatial dim.
        stride: stride per spatial dim.
        padding: padding per spatial dim.
        dilation: dilation per spatial dim.
        transposed: defaults to False; presumably selects transposed conv -- TODO confirm.
        stream_int: defaults to 0; presumably a CUDA stream handle as int -- TODO confirm.
    """
    ...
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
Args:
......@@ -427,30 +503,45 @@ class SpconvOps:
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> int:
def get_handcrafted_max_act_out(num_act_in: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int]) -> int:
    """
    Signature of the helper that returns an int derived from the input
    point count and the convolution parameters.

    Only the signature is declared here (the body is ``...``).
    NOTE(review): judging by the name, the result is presumably an estimated
    upper bound on the number of active outputs -- confirm.

    Args:
        num_act_in: presumably the number of active input points -- TODO confirm.
        ksize: kernel size per spatial dim.
        stride: stride per spatial dim.
        padding: padding per spatial dim.
        dilation: dilation per spatial dim.
    """
    ...
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> int:
    """
    Signature of the helper that returns the indice-generation workspace size.

    Only the signature is declared here (the body is ``...``).
    NOTE(review): the unit of the result is presumably bytes, and the
    argument descriptions are inferred from names -- confirm.

    Args:
        kv: presumably the kernel volume -- TODO confirm.
        num_act_in: presumably the number of active input points -- TODO confirm.
        num_act_out_bound: presumably an upper bound on active outputs -- TODO confirm.
        max_act_out_in_theory: presumably the theoretical max output count
            (see ``get_handcrafted_max_act_out``) -- TODO confirm.
        subm: presumably True for submanifold convolution -- TODO confirm.
        use_int64_hash_k: presumably True when hash keys are 64-bit -- TODO confirm.
        direct_table: presumably True for the direct-table path -- TODO confirm.
    """
    ...
@staticmethod
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> Dict[str, Tensor]:
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> Dict[str, Tensor]:
"""
Args:
workspace:
kv:
num_act_in:
num_act_out_bound:
max_act_out_in_theory:
subm:
use_int64_hash_k:
direct_table:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
"""
Args:
allocator:
......@@ -468,6 +559,9 @@ class SpconvOps:
is_train:
stream_int:
num_out_act_bound:
timer:
direct_table:
preallocated:
"""
...
@staticmethod
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from typing import List
from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib, GemmBasicHost
from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib, GemmBasicHost, CppTimer
import cumm
from cumm.conv.bases import ConvOpType, NHWC
from cumm.conv.params import ConvProblem
......@@ -27,7 +27,7 @@ from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndice
from .maxpool import IndiceMaxPool, IndiceMaxPoolCPU
from .gather import GatherCPU
from .alloc import ExternalAllocator, ThrustAllocator
from spconv.constants import AllocKeys
from spconv.constants import SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, AllocKeys
import re
class CustomThrustLib(pccm.Class):
......@@ -78,6 +78,11 @@ def to_snake_case(name):
name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name)
return name.lower()
class HashCoreHost(pccm.Class):
    """pccm class whose only job is to pull in the tensorview hash-core header."""

    def __init__(self):
        super(HashCoreHost, self).__init__()
        hash_core_header = "tensorview/hash/hash_core.h"
        self.add_include(hash_core_header)
class SpconvOps(pccm.Class):
def __init__(self):
super().__init__()
......@@ -104,7 +109,10 @@ class SpconvOps(pccm.Class):
self.generate_conv_inds_stage1_5,
self.generate_conv_inds_stage2, self.sort_1d_by_key,
self.generate_conv_inds_mask_stage1,
self.generate_conv_inds_mask_stage2
self.generate_conv_inds_mask_stage2,
self.unique_hash, self.assign_output_direct_hash,
self.generate_conv_inds_mask_stage1_direct_table,
self.generate_conv_inds_stage2_mask_direct_table
]
self.add_impl_only_param_class(cuda_funcs, f"ops{ndim}d",
indices,
......@@ -306,6 +314,110 @@ class SpconvOps(pccm.Class):
return code # .ret("int")
@pccm.pybind.mark
@pccm.cuda.static_function
def generate_conv_inds_mask_stage1_direct_table(self):
    """Build the C++ dispatcher for the direct-table stage-1 indice generation.

    The emitted function reads ``ndim`` from ``indices.dim(1) - 1``, checks
    that every per-dimension vector has ``ndim`` entries, copies them into
    fixed-size ``tv::array`` values and forwards to
    ``SpconvIndices{ndim}D::generate_conv_inds_mask_stage1_direct_table``
    for each ``ndim`` in ``self.ndims``. An unsupported ``ndim`` raises a
    runtime error.

    Returns:
        The ``pccm.FunctionCode``; an invalid one in CPU-only builds. No
        return type is declared for the emitted function.
    """
    code = pccm.FunctionCode()
    # This function is CUDA-only; mark it invalid when building for CPU.
    if CUMM_CPU_ONLY_BUILD:
        return code.make_invalid()
    code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
    code.arg("indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc",
    "tv::Tensor")
    code.arg("batch_size", "int")
    code.arg("output_dims, input_dims", f"std::vector<int>")
    code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
    code.arg("transposed", f"bool", "false")
    code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
    # Runtime validation of the dynamic-size parameters against ndim.
    code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
""")
    # One `if (ndim == N)` branch per supported dimensionality.
    for ndim in self.ndims:
        code.raw(f"""
if (ndim == {ndim}){{
tv::array<int, {ndim}> output_dims_, input_dims_;
tv::array<int, {ndim}> ksize_, stride_, padding_, dilation_;
for (int i = 0; i < {ndim}; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndices{ndim}D::generate_conv_inds_mask_stage1_direct_table(indices,
hashdata_k, hashdata_v, indice_pairs_bwd, indice_pairs_uniq,
indice_num_per_loc, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
""")
    # Reached only when no branch above matched.
    code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
    return code  # no explicit return type is set (an "int" return was left disabled)
@pccm.pybind.mark
@pccm.cuda.static_function
def unique_hash(self):
    """Build the C++ wrapper that forwards to ``SpconvIndices3D::unique_hash``.

    The emitted function takes four tensors, an int bound and a stream
    handle, passes them through unchanged and returns the callee's ``int``
    result. CPU-only builds get an invalid function instead.
    """
    fn = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        return fn.make_invalid()
    tensor_params = "hashdata_k, hashdata_v, uniq_cnt, out_indices_offset"
    fn.arg(tensor_params, "tv::Tensor")
    fn.arg("num_out_bound", "int")
    fn.arg("stream_int", "std::uintptr_t", "0", pyanno="int")
    forward_call = (
        "\n"
        "return SpconvIndices3D::unique_hash(hashdata_k, hashdata_v,\n"
        "uniq_cnt, out_indices_offset, num_out_bound, stream_int);\n"
    )
    fn.raw(forward_call)
    return fn.ret("int")
@pccm.pybind.mark
@pccm.cuda.static_function
def assign_output_direct_hash(self):
    """Build the C++ dispatcher for ``assign_output_direct_hash``.

    The emitted function reads ``ndim`` from ``out_indices.dim(1) - 1``,
    checks that every per-dimension vector has ``ndim`` entries, copies them
    into fixed-size ``tv::array`` values and forwards to
    ``SpconvIndices{ndim}D::assign_output_direct_hash`` for each ``ndim`` in
    ``self.ndims``. An unsupported ``ndim`` raises a runtime error.

    Returns:
        The ``pccm.FunctionCode``; an invalid one in CPU-only builds. No
        return type is declared for the emitted function.
    """
    code = pccm.FunctionCode()
    # This function is CUDA-only; mark it invalid when building for CPU.
    if CUMM_CPU_ONLY_BUILD:
        return code.make_invalid()
    code.arg("out_indices_offset, out_indices", "tv::Tensor")
    code.arg("batch_size", "int")
    code.arg("output_dims, input_dims", f"std::vector<int>")
    code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
    code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
    # Here ndim comes from out_indices, not from an `indices` argument.
    code.raw(f"""
int ndim = out_indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
""")
    # One `if (ndim == N)` branch per supported dimensionality.
    for ndim in self.ndims:
        code.raw(f"""
if (ndim == {ndim}){{
tv::array<int, {ndim}> output_dims_, input_dims_;
tv::array<int, {ndim}> ksize_, stride_, padding_, dilation_;
for (int i = 0; i < {ndim}; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndices{ndim}D::assign_output_direct_hash(
out_indices_offset, out_indices, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, stream_int);
}}
""")
    # Reached only when no branch above matched.
    code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
    return code
@pccm.pybind.mark
@pccm.cuda.static_function
def generate_conv_inds_mask_stage2(self):
......@@ -356,6 +468,55 @@ class SpconvOps(pccm.Class):
return code.ret("int")
@pccm.pybind.mark
@pccm.cuda.static_function
def generate_conv_inds_stage2_mask_direct_table(self):
    """Build the C++ dispatcher for the direct-table stage-2 mask generation.

    The emitted function reads ``ndim`` from ``indices.dim(1) - 1``, checks
    that every per-dimension vector has ``ndim`` entries, copies them into
    fixed-size ``tv::array`` values and forwards to
    ``SpconvIndices{ndim}D::generate_conv_inds_stage2_mask_direct_table``
    for each ``ndim`` in ``self.ndims``, returning the callee's result. An
    unsupported ``ndim`` raises a runtime error.

    Returns:
        The ``pccm.FunctionCode`` with return type ``int``; an invalid one
        in CPU-only builds.
    """
    code = pccm.FunctionCode()
    # This function is CUDA-only; mark it invalid when building for CPU.
    if CUMM_CPU_ONLY_BUILD:
        return code.make_invalid()
    code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
    code.arg(
    "indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds",
    "tv::Tensor")
    code.arg("mask_fwd, mask_bwd", "tv::Tensor")
    code.arg("num_out_act", "int")
    code.arg("batch_size", "int")
    code.arg("output_dims, input_dims", f"std::vector<int>")
    code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
    code.arg("transposed", f"bool", "false")
    code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
    # Runtime validation of the dynamic-size parameters against ndim.
    code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
""")
    # One `if (ndim == N)` branch per supported dimensionality.
    for ndim in self.ndims:
        code.raw(f"""
if (ndim == {ndim}){{
tv::array<int, {ndim}> output_dims_, input_dims_;
tv::array<int, {ndim}> ksize_, stride_, padding_, dilation_;
for (int i = 0; i < {ndim}; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndices{ndim}D::generate_conv_inds_stage2_mask_direct_table(
indices, hashdata_k, hashdata_v,
indice_pairs_fwd, indice_pairs_bwd,
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
""")
    # Reached only when no branch above matched.
    code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
    return code.ret("int")
@pccm.pybind.mark
@pccm.cuda.static_function
def generate_subm_conv_inds(self):
......@@ -718,53 +879,6 @@ class SpconvOps(pccm.Class):
""")
return code
@pccm.pybind.mark
@pccm.cuda.static_function
def sort_1d_by_key(self):
    """Build the CUDA function that sorts ``data`` and returns the permutation.

    The emitted function fills ``indices`` with an arange (allocating an
    int32 tensor of ``data.dim(0)`` elements when none is passed), then runs
    ``thrust::stable_sort_by_key`` on the given stream with ``data`` as keys
    and ``indices`` as values, and returns ``indices``. ``data`` is
    dispatched over int32/uint32/int64/uint64.

    NOTE(review): the comparator is always ``SmallOrEqualTo<uint32_t>``,
    regardless of the dispatched type ``T`` -- confirm this is intended for
    signed and 64-bit data.
    NOTE(review): despite its name, ``SmallOrEqualTo`` implements strict
    ``<``. ``mask_input`` is defined but not launched in this body, and the
    ``timer`` is only referenced by a commented-out print.

    Returns:
        The ``pccm.FunctionCode`` with return type ``tv::Tensor``; an
        invalid one in CPU-only builds.
    """
    code = pccm.FunctionCode()
    # This function is CUDA-only; mark it invalid when building for CPU.
    if CUMM_CPU_ONLY_BUILD:
        return code.make_invalid()
    code.arg("data", "tv::Tensor")
    code.arg("indices",
    "tv::Tensor",
    "tv::Tensor()",
    pyanno="cumm.tensorview.Tensor = Tensor()")
    code.arg("stream", "std::uintptr_t", "0", pyanno="int")
    # Helper comparator and kernel emitted right after the includes.
    code.code_after_include = f"""
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
    code.add_dependency(CustomThrustLib, TensorViewKernel)
    # Gives the emitted code access to `cudakers::arange_kernel`.
    code.add_param_class("cudakers", self.cuda_common_kernel)
    code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
thrust::stable_sort_by_key(thrust_ctx, ptr_tr, ptr_tr + data.dim(0), ptr_k, SmallOrEqualTo<uint32_t>());
}});
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
""")
    return code.ret("tv::Tensor")
def sort_1d_by_key_allocator_template(self, use_allocator: bool):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
......@@ -1379,6 +1493,29 @@ class SpconvOps(pccm.Class):
""")
return code.ret("int")
@pccm.pybind.mark
@pccm.static_function
def get_handcrafted_max_act_out(self):
    """Build the C++ helper that scales ``num_act_in`` by the kernel/stride ratio.

    For every spatial dimension where ``ksize[i] > stride[i]`` the running
    result is multiplied by ``tv::div_up(ksize[i], stride[i])``; dimensions
    with ``ksize[i] <= stride[i]`` leave it unchanged. ``padding`` and
    ``dilation`` are accepted but not used by the emitted code.

    The original emitted an ``else`` branch after the exhaustive
    ``<=`` / ``>`` pair (unreachable) and a no-op ``res *= 1``; both are
    removed here without changing the computed value. The loop index is
    ``size_t`` to avoid a signed/unsigned comparison with ``ksize.size()``.

    NOTE(review): the result is accumulated in ``int`` and can overflow for
    large inputs; ``stride`` is indexed with ``ksize``'s length without a
    size check -- confirm callers always pass equally sized vectors.

    Returns:
        The ``pccm`` code object with return type ``int``.
    """
    code = pccm.code()
    code.arg("num_act_in", "size_t")
    code.arg("ksize, stride, padding, dilation", "std::vector<int>")
    code.raw(f"""
int res = num_act_in;
for (size_t i = 0; i < ksize.size(); ++i){{
// ksize <= stride contributes a factor of 1, so only the other case is handled.
if (ksize[i] > stride[i]){{
res *= tv::div_up(ksize[i], stride[i]);
}}
}}
return res;
""")
    return code.ret("int")
@pccm.pybind.mark
@pccm.static_function
def get_indice_gen_workspace_size(self):
......@@ -1386,15 +1523,20 @@ class SpconvOps(pccm.Class):
code.arg("kv", "size_t")
code.arg("num_act_in", "size_t")
code.arg("num_act_out_bound", "size_t")
code.arg("subm, use_int64_hash_k", "bool")
code.arg("max_act_out_in_theory", "size_t")
code.arg("subm, use_int64_hash_k, direct_table", "bool")
# Emit the workspace-size computation.
# `hash_size` is the hash table size (twice the output bound, or the scaled
# theoretical maximum for the direct-table path). Its size in bytes is kept
# in a separate variable, `hash_bytes`: the previous text declared
# `size_t hash_size = hash_size * ...` inside the else-branch, which shadows
# the outer variable and initializes the new one from itself (indeterminate
# value). Both branches add one extra int at the end of the workspace.
code.raw(f"""
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory);
}}
size_t hash_bytes = hash_size * (use_int64_hash_k ? 3 : 2) * sizeof(int);
if (subm){{
return hash_bytes + 1 * sizeof(int);
}}else{{
size_t pair_single_size = kv * num_act_in; // 40000
size_t ind_uniq_and_bkp_size = (pair_single_size + 1) * 2 * (use_int64_hash_k ? sizeof(int64_t) : sizeof(int32_t));
return ind_uniq_and_bkp_size + hash_bytes + 1 * sizeof(int);
}}
""")
return code.ret("std::size_t")
......@@ -1407,20 +1549,26 @@ class SpconvOps(pccm.Class):
code.arg("kv", "size_t")
code.arg("num_act_in", "size_t")
code.arg("num_act_out_bound", "size_t")
code.arg("subm, use_int64_hash_k", "bool")
code.arg("max_act_out_in_theory", "size_t")
code.arg("subm, use_int64_hash_k, direct_table", "bool")
code.raw(f"""
std::unordered_map<std::string, tv::Tensor> res;
auto ws_prev = workspace;
auto expected_size = get_indice_gen_workspace_size(kv, num_act_in, num_act_out_bound, subm, use_int64_hash_k);
auto expected_size = get_indice_gen_workspace_size(kv, num_act_in, num_act_out_bound,
max_act_out_in_theory, subm, use_int64_hash_k, direct_table);
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory);
}}
if (use_int64_hash_k){{
auto ten = tv::from_blob(workspace, {{int64_t(num_act_out_bound) * 2}}, tv::int64, 0);
auto ten = tv::from_blob(workspace, {{int64_t(hash_size) * 2}}, tv::int64, 0);
res.insert({{{pccm.literal(AllocKeys.HashKOrKV)}, ten}});
workspace += ten.nbytes();
auto ten2 = tv::from_blob(workspace, {{int64_t(num_act_out_bound) * 2}}, tv::int32, 0);
auto ten2 = tv::from_blob(workspace, {{int64_t(hash_size) * 2}}, tv::int32, 0);
res.insert({{{pccm.literal(AllocKeys.HashV)}, ten2}});
workspace += ten2.nbytes();
}}else{{
auto ten = tv::from_blob(workspace, {{2, int64_t(num_act_out_bound) * 2}}, tv::int32, 0);
auto ten = tv::from_blob(workspace, {{2, int64_t(hash_size) * 2}}, tv::int32, 0);
res.insert({{{pccm.literal(AllocKeys.HashKOrKV)}, ten}});
workspace += ten.nbytes();
}}
......@@ -1433,6 +1581,10 @@ class SpconvOps(pccm.Class):
res.insert({{{pccm.literal(AllocKeys.IndicePairsUniqBackup)}, ten2}});
workspace += ten2.nbytes();
}}
auto uniq_cnt = tv::from_blob(workspace, {{1}}, tv::int32, 0);
res.insert({{{pccm.literal(AllocKeys.TightUniqueCount)}, uniq_cnt}});
workspace += uniq_cnt.nbytes();
TV_ASSERT_RT_ERR(workspace - ws_prev == expected_size, "this shouldn't happen");
return res;
""")
......@@ -1442,6 +1594,7 @@ class SpconvOps(pccm.Class):
@pccm.static_function
def get_indice_pairs_implicit_gemm(self):
code = pccm.code()
code.add_dependency(HashCoreHost)
code.arg("allocator", "ExternalAllocator&")
code.arg("indices", "tv::Tensor")
code.arg("batch_size", "int")
......@@ -1452,12 +1605,18 @@ class SpconvOps(pccm.Class):
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("num_out_act_bound", f"int", "-1")
code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)",
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)")
code.arg("direct_table", f"bool", "false")
code.arg("preallocated", f"std::unordered_map<std::string, tv::Tensor>",
"std::unordered_map<std::string, tv::Tensor>{}",
"Dict[str, cumm.tensorview.Tensor] = {}")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
throw std::runtime_error("this function can only be used with CUDA.")
""")
return code.ret("tv::Tensor")
return code.ret("std::tuple<tv::Tensor, int>")
code.raw(f"""
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
......@@ -1479,20 +1638,24 @@ class SpconvOps(pccm.Class):
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
}}
}}
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm, "only support implicit gemm");
bool is_mask_split = conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm;
int mask_split_count = is_mask_split ? 2 : 1;
tv::Tensor pair;
int64_t num_act_in = indices.dim(0);
""")
code.raw(f"""
tv::Tensor pair;
if (subm){{
if (preallocated.find({pccm.literal(AllocKeys.PairFwd)}) != preallocated.end()){{
pair = preallocated.at({pccm.literal(AllocKeys.PairFwd)});
}}
else{{
if (is_train){{
// query pair for fwd and bwd
pair = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
......@@ -1502,6 +1665,7 @@ class SpconvOps(pccm.Class):
pair = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{1, kv, num_act_in}}, -1, indices.dtype(), indices.device(), stream_int);
}}
}}
}}else{{
if (is_train){{
// query pair bwd
......@@ -1512,9 +1676,17 @@ class SpconvOps(pccm.Class):
pair = tv::Tensor();
}}
}}
""")
auto indice_num_per_loc = allocator.zeros({pccm.literal(AllocKeys.IndiceNumPerLoc)},
code.raw(f"""
tv::Tensor indice_num_per_loc;
if (preallocated.find({pccm.literal(AllocKeys.IndiceNumPerLoc)}) != preallocated.end()){{
indice_num_per_loc = preallocated.at({pccm.literal(AllocKeys.IndiceNumPerLoc)});
}}
else{{
indice_num_per_loc = allocator.zeros({pccm.literal(AllocKeys.IndiceNumPerLoc)},
{{kv}}, indices.dtype(), indices.device(), stream_int);
}}
tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1);
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
......@@ -1533,29 +1705,45 @@ class SpconvOps(pccm.Class):
tv::Tensor out_inds;
ThrustAllocator thrustalloc(allocator);
int num_act_out = 0;
if (subm){{
""")
with code.if_("subm"):
code.raw(f"""
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
out_inds = indices;
num_act_out = indices.dim(0);
int num_points = out_inds.dim(0);
int hash_size = out_inds.dim(0) * 2;
""")
code.raw(f"""
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}},
hash_k_guard = allocator.empty_guard({{hash_size}},
tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_v_gurad = allocator.empty_guard({{num_points * 2}},
hash_v_gurad = allocator.empty_guard({{hash_size}},
tv::int32, 0, {pccm.literal(AllocKeys.HashV)});
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}},
if (preallocated.find({pccm.literal(AllocKeys.HashKOrKV)}) != preallocated.end()){{
auto hash_kv = preallocated.at({pccm.literal(AllocKeys.HashKOrKV)});
hash_k = hash_kv[0];
hash_v = hash_kv[1];
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, hash_size}},
tv::int32, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
auto pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
}}
""")
code.raw(f"""
tv::Tensor pair_mask;
if (preallocated.find({pccm.literal(AllocKeys.PairMask)}) != preallocated.end()){{
pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)});
}}else{{
pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_in}}, tv::uint32, 0, stream_int);
}}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
......@@ -1563,64 +1751,135 @@ class SpconvOps(pccm.Class):
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
}}
}}else{{
""")
with code.else_():
code.raw(f"""
// auto start = tv::CPUEvent().record(stream_int);
auto pair_bwd = pair;
auto pair_size = kv * num_act_in;
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair_size + 1)}},
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
ExternalAllocator::guard_t indice_pairs_uniq_guard, indice_pairs_uniq_bkp_guard;
tv::Tensor hash_k, hash_v, indice_pairs_uniq;
int max_num_act = get_handcrafted_max_act_out(num_act_in, ksize, stride, padding, dilation);
if (transposed){{
max_num_act = pair_size;
}}
int hash_size = int(max_num_act * {SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE});
if (direct_table){{
if (use_int64_hash_k){{
// temp memory don't need to be fixed, static alloc will check
// that tensor is large enough.
hash_k_guard = allocator.empty_guard({{hash_size}},
tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_v_gurad = allocator.empty_guard({{hash_size}},
tv::int32, 0, {pccm.literal(AllocKeys.HashV)});
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, hash_size}},
tv::int32, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
}}
indice_pairs_uniq_guard = allocator.empty_guard({{2, int64_t(pair_size + 1)}},
indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniq)});
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair_size + 1)}},
indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniqBackup)});
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
indice_pairs_uniq = indice_pairs_uniq_guard->tensor[0];
auto indice_pairs_uniq_bkp = indice_pairs_uniq_guard->tensor[1];
// indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair_size + 1)}},
// indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniqBackup)});
{{
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_stage1",
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (direct_table){{
generate_conv_inds_mask_stage1_direct_table(indices,
hash_k, hash_v, pair_bwd, indice_pairs_uniq_bkp,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed, stream_int);
}}else{{
generate_conv_inds_mask_stage1(indices, pair_bwd, indice_pairs_uniq,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed, stream_int);
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
indice_pairs_uniq_bkp.copy_(indice_pairs_uniq, tvctx);
}}
}}
// TODO pytorch unique run faster.
{{
tv::CUDAKernelTimerGuard timer_guard(std::string("unique_") + std::to_string(indice_pairs_uniq.dim(0)),
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (direct_table){{
auto uniqcnt = allocator.zeros_guard({{1}}, tv::int32, 0,
{pccm.literal(AllocKeys.TightUniqueCount)}, stream_int);
num_act_out = unique_hash(hash_k, hash_v, uniqcnt->tensor,
indice_pairs_uniq, num_out_act_bound, stream_int);
}}else{{
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
}}
}}
// tv::ssprint("HASH SIZE", hash_size, num_act_out);
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound;
}}
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
// for fixed size allocator, all memory alloc size must be fixed.
tv::Tensor pair_fwd, pair_mask_fwd, pair_mask_bwd;
{{
tv::CUDAKernelTimerGuard timer_guard("alloc_stage2",
timer, reinterpret_cast<cudaStream_t>(stream_int));
out_inds = allocator.empty({pccm.literal(AllocKeys.OutIndices)},
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0, stream_int);
auto pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
auto pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_out}}, tv::uint32, 0, stream_int);
auto pair_mask_bwd = tv::Tensor();
pair_mask_bwd = tv::Tensor();
if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0, stream_int);
}}
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
}}
if (!direct_table){{
int hash_size = int(num_act_out * 2);
if (use_int64_hash_k){{
// temp memory don't need to be fixed, static alloc will check
// that tensor is large enough.
hash_k_guard = allocator.empty_guard({{num_act_out * 2}},
hash_k_guard = allocator.empty_guard({{hash_size}},
tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}},
hash_v_gurad = allocator.empty_guard({{hash_size}},
tv::int32, 0, {pccm.literal(AllocKeys.HashV)});
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}},
hash_kv_gurad = allocator.empty_guard({{2, hash_size}},
tv::int32, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
}}
{{
tv::CUDAKernelTimerGuard timer_guard(std::string("gen_conv_inds_stage2_") + std::to_string(num_act_out),
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (direct_table){{
assign_output_direct_hash(indice_pairs_uniq, out_inds,
batch_size, out_shape,
input_dims, ksize, stride, padding, dilation, stream_int);
generate_conv_inds_stage2_mask_direct_table(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
}}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp_guard->tensor,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
}}
}}
""")
code.raw(f"""
auto mask_argsort_fwd = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, num_act_out}}, tv::int32, 0, stream_int);
tv::Tensor mask_argsort_bwd = tv::Tensor();
......@@ -1628,7 +1887,9 @@ class SpconvOps(pccm.Class):
mask_argsort_bwd = allocator.zeros({pccm.literal(AllocKeys.MaskArgSortBwd)},
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
}}
{{
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_sort",
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (is_mask_split){{
for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
......@@ -1653,8 +1914,9 @@ class SpconvOps(pccm.Class):
mask_argsort_bwd[0], stream_int);
}}
}}
}}
""")
code.raw(f"""
return std::make_tuple(mask_tensor, num_act_out);
""")
return code.ret("std::tuple<tv::Tensor, int>")
......
......@@ -73,7 +73,9 @@ class CudaCommonKernel(pccm.ParameterizedClass):
""")
return code
class ConvOutLocIter(pccm.ParameterizedClass):
def __init__(self, problem: ConvProblem):
super().__init__()
self.add_dependency(TensorView)
......@@ -264,6 +266,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
class SparseConvIndicesKernel(pccm.ParameterizedClass):
def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType):
super().__init__()
self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel)
......@@ -278,7 +281,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
assert dtype_indices == dtypes.int32 or dtype_indices == dtypes.int64
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage1(self):
code = pccm.FunctionCode()
......@@ -331,7 +333,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_out", f"int*") # [N, ndim + 1]
code.arg("indice_pairs_for_uniq",
code.arg(
"indice_pairs_for_uniq",
f"const typename TTable::key_type*") # [2, kernelProd, MaxSize]
code.arg("layout_npq",
......@@ -349,12 +352,86 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2(self):
def arange_hash_table_and_assign_out(self):
    """Emit the kernel that numbers occupied hash slots and fills ``indices_out``.

    The generated kernel walks every slot of ``table``. For each slot whose
    key is not ``TTable::empty_key`` it draws the next index from ``count``
    (``tv::cuda::atomicAggInc``). When that index is below ``limit`` it is
    stored as the slot's value and ``layout_npq.inverse`` is called with the
    key and the row ``indices_out + (ndim + 1) * index``; otherwise the
    slot's value is set to ``-1`` and nothing is written to ``indices_out``.

    Kernel parameters (in order):
        table: hash table of template type ``TTable``.
        indices_out: ``int*`` output rows, ``ndim + 1`` ints per row.
        count: ``int*`` counter incremented once per occupied slot.
        limit: ``int`` upper bound on the indices that are kept.
        layout_npq: ``spinds::LayoutNPQ`` used to invert the stored key
            (presumably flat offset -> coordinates; confirm in LayoutNPQ).

    Returns:
        The populated ``pccm.FunctionCode``.
    """
    kernel = pccm.FunctionCode()
    kernel.targ("TTable")
    # Order matters: it is the parameter order of the generated kernel.
    kernel_params = (
        ("table", "TTable"),
        ("indices_out", "int*"),
        ("count", "int*"),
        ("limit", "int"),
        ("layout_npq", "spinds::LayoutNPQ"),
    )
    for param_name, param_type in kernel_params:
        kernel.arg(param_name, param_type)
    kernel.raw(f"""
auto key_ptr = table.key_ptr();
auto value_ptr = table.value_ptr();
for (auto i : tv::KernelLoopX<int>(table.size())) {{
auto output_coord_offset = key_ptr[i];
if (output_coord_offset != TTable::empty_key) {{
auto output_index = tv::cuda::atomicAggInc(count);
if (output_index < limit){{
value_ptr[i] = output_index;
layout_npq.inverse(output_coord_offset, indices_out + {self.ndim + 1} * output_index);
}}else{{
value_ptr[i] = -1;
}}
}}
}}
""")
    return kernel
@pccm.cuda.cuda_global_function
def arange_hash_table(self):
    """Emit the kernel that numbers occupied hash slots and collects their keys.

    The generated kernel walks every slot of ``table``. For each slot whose
    key is not ``TTable::empty_key`` it draws the next index from ``count``
    (``tv::cuda::atomicAggInc``). Indices below ``limit`` are stored as the
    slot's value and the key is copied to ``out_indices_offset[index]``;
    slots that land at or beyond ``limit`` get the value ``-1``.

    Fix: the key used to be written to ``out_indices_offset[index]`` even
    when ``index >= limit``. The host wrapper ``unique_hash`` falls back to
    ``limit = out_indices_offset.dim(0)``, so such a write could land past
    the end of the buffer. The write is now guarded exactly like the sibling
    kernel ``arange_hash_table_and_assign_out``. ``count`` still counts every
    occupied slot, so callers that clamp the returned count are unaffected.

    Kernel parameters (in order):
        table: hash table of template type ``TTable``.
        out_indices_offset: ``typename TTable::key_type*`` receiving the
            kept keys, one per assigned index.
        count: ``int*`` counter incremented once per occupied slot.
        limit: ``int`` upper bound on the indices that are kept.

    Returns:
        The populated ``pccm.FunctionCode``.
    """
    code = pccm.FunctionCode()
    code.targ("TTable")
    code.arg("table", "TTable")
    code.arg("out_indices_offset", "typename TTable::key_type *")
    code.arg("count", "int*")
    code.arg("limit", "int")
    code.raw(f"""
    auto key_ptr = table.key_ptr();
    auto value_ptr = table.value_ptr();
    for (auto i : tv::KernelLoopX<int>(table.size())) {{
        auto output_coord_offset = key_ptr[i];
        if (output_coord_offset != TTable::empty_key) {{
            auto output_index = tv::cuda::atomicAggInc(count);
            if (output_index < limit) {{
                value_ptr[i] = output_index;
                out_indices_offset[output_index] = output_coord_offset;
            }} else {{
                // over the bound: mark the slot as dropped and never write
                // past `limit` entries of out_indices_offset.
                value_ptr[i] = -1;
            }}
        }}
    }}
    """)
    return code
@pccm.cuda.cuda_global_function
def assign_out_indices(self):
    """Emit the kernel that expands stored offsets into output index rows.

    For every ``i`` in ``[0, size)`` the generated kernel calls
    ``layout_npq.inverse(out_indices_offset[i], indices_out + (ndim + 1) * i)``,
    i.e. one row of ``ndim + 1`` ints per input offset (the exact row
    contents are defined by ``spinds::LayoutNPQ::inverse``).

    Kernel parameters (in order):
        indices_out: ``int*`` output rows, ``ndim + 1`` ints per row.
        out_indices_offset: ``const T*`` offsets to invert (``T`` is the
            kernel's template type).
        layout_npq: ``spinds::LayoutNPQ`` performing the inversion.
        size: ``int`` number of offsets to process.

    Returns:
        The populated ``pccm.FunctionCode``.
    """
    kernel = pccm.FunctionCode()
    kernel.targ("T")
    # Order matters: it is the parameter order of the generated kernel.
    for param_name, param_type in (
            ("indices_out", "int*"),
            ("out_indices_offset", "const T*"),
            ("layout_npq", "spinds::LayoutNPQ"),
            ("size", "int"),
    ):
        kernel.arg(param_name, param_type)
    kernel.raw(f"""
for (auto i : tv::KernelLoopX<int>(size)) {{
layout_npq.inverse(out_indices_offset[i], indices_out + {self.ndim + 1} * i);
}}
""")
    return kernel
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2(self):
code = pccm.FunctionCode()
code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_uniq_before_sort", f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("indice_pairs_uniq_before_sort",
f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("indice_pairs_out_part", f"int*") # [kernelProd, MaxSize]
code.arg("num_indices_in", "int")
code.arg("indices_pair_size", "int")
......@@ -362,7 +439,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
int filter_offset = blockIdx.y;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * indices_pair_size;
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
{self.dtype_indices} output_coord_offset = indice_pairs_uniq_before_sort_filter[i];
if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
......@@ -386,8 +462,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_uniq_before_sort", f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("indice_pairs_in_part_temp", f"const int*") # [kernelProd, MaxSize]
code.arg("indice_pairs_uniq_before_sort",
f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("indice_pairs_in_part_temp",
f"const int*") # [kernelProd, MaxSize]
code.arg("indice_pairs_in_part", f"int*") # [kernelProd, MaxSize]
code.arg("indice_pairs_out_part", f"int*") # [kernelProd, MaxSize]
code.arg("indice_num_per_loc", f"int*") # [kernelProd]
......@@ -448,13 +526,63 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
valid = loc_iter.query_npq(indices_in + input_index * {self.ndim + 1}, npq_offset);
}}
if (valid){{
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
// int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
TIndiceUniq output_coord_offset = loc_iter.layout_npq(npq_offset);
// if (old_num < indices_pair_size){{
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// }}
}}
}}
""")
return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask_direct_table(self):
code = pccm.FunctionCode()
code.targ("TIndiceUniq")
code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1]
code.arg("indice_pairs_bwd",
f"{self.dtype_indices}*") # [kernelProd, MaxSize]
code.arg("indice_pairs_for_uniq",
f"TIndiceUniq*") # [kernelProd * MaxSize + 1]
code.arg("indice_num_per_loc", f"int*") # [kernelProd]
code.arg("num_indices_in", "int")
code.arg("RS", "int")
code.arg("transposed", "bool")
code.raw(f"""
int filter_offset = blockIdx.y;
loc_iter.set_filter_offset(filter_offset);
// int indices_pair_size_mul_RS = num_indices_in * RS;
int filter_offset_mul_indices_pair_size = filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
tv::array<int, {self.ndim + 1}> npq_offset;
bool valid;
if (transposed){{
valid = loc_iter.query_nhw_out(indices_in + input_index * {self.ndim + 1}, npq_offset);
}}else{{
valid = loc_iter.query_npq(indices_in + input_index * {self.ndim + 1}, npq_offset);
}}
if (valid){{
// int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
TIndiceUniq output_coord_offset = loc_iter.layout_npq(npq_offset);
// if (old_num < indices_pair_size){{
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
table.insert_key_only(output_coord_offset);
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// }}
}}
......@@ -466,12 +594,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_conv_indices_stage2_mask(self):
code = pccm.FunctionCode()
code.targ("TTable")
code.nontype_targ("CheckValueValid", "bool")
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_fwd",
f"int*") # [kernelProd, MaxSize], inp -> out
code.arg("indice_pairs_bwd",
f"int*") # [kernelProd, MaxSize], out -> inp
code.arg("indice_pairs_uniq_before_sort", f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("indice_pairs_uniq_before_sort",
f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("mask_fwd", f"uint32_t*") # [kernelProd]
code.arg("mask_bwd", f"uint32_t*") # [kernelProd]
......@@ -495,6 +626,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto table_offset = table.lookup_offset(output_coord_offset);
if (table_offset != -1){{
auto output_index = table.value_ptr()[table_offset];
bool valid = CheckValueValid ? output_index >= 0 : true;
if (valid){{
atomicOr(mask_fwd + output_index, filter_mask_fwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd);
indice_pairs_fwd_filter[output_index] = input_index;
......@@ -504,6 +637,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}}
}}
}}
}}
""")
return code
......@@ -533,13 +667,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_conv_indices_stage2_inference_mask(self):
code = pccm.FunctionCode()
code.targ("TTable")
code.nontype_targ("CheckValueValid", "bool")
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_fwd",
f"int*") # [kernelProd, MaxSize], inp -> out
code.arg("indice_pairs_bwd",
f"int*") # [kernelProd, MaxSize], out -> inp
code.arg("indice_pairs_uniq_before_sort", f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("indice_pairs_uniq_before_sort",
f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("mask_fwd", f"uint32_t*") # [kernelProd]
code.arg("num_indices_in", "int")
......@@ -559,11 +695,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto table_offset = table.lookup_offset(output_coord_offset);
if (table_offset != -1){{
auto output_index = table.value_ptr()[table_offset];
bool valid = CheckValueValid ? output_index >= 0 : true;
if (valid){{
atomicOr(mask_fwd + output_index, filter_mask_fwd);
indice_pairs_fwd_filter[output_index] = input_index;
}}
}}
}}
}}
""")
return code
......@@ -854,7 +993,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def generate_conv_inds_stage2(self):
code = pccm.FunctionCode()
code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds", "tv::Tensor")
code.arg(
"indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds",
"tv::Tensor")
code.arg("indice_num_per_loc", "tv::Tensor")
code.arg("num_out_act", "int")
......@@ -938,8 +1079,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def generate_conv_inds_mask_stage1(self):
code = pccm.FunctionCode()
code.arg("indices", "tv::Tensor")
code.arg("indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc",
"tv::Tensor")
code.arg("indice_pairs_bwd, indice_pairs_uniq", "tv::Tensor")
code.arg("indice_num_per_loc", "tv::Tensor")
code.arg("batch_size", "int")
code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>")
code.arg("ksize, stride, padding, dilation",
......@@ -982,8 +1123,67 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""")
return code # .ret("int")
@pccm.cuda.static_function
def generate_conv_inds_stage2_mask(self):
def generate_conv_inds_mask_stage1_direct_table(self):
    """Emit the host function for stage 1 of direct-table pair generation.

    The generated function validates tensor shapes, clears the split hash
    table backed by ``hashdata_k`` / ``hashdata_v``, resets
    ``indice_pairs_uniq`` with ``clean_indices_uniq`` and launches
    ``calc_conv_indices_stage1_mask_direct_table`` with one ``blockIdx.y``
    per kernel offset (``launcher_num_act_in.blocks.y = kv``).

    Fixes compared to the previous version:
      * ``uniq_size`` was computed as ``kv * num_act_in + 1`` in 32-bit
        ``int`` before being widened, which is signed-overflow UB for large
        inputs. It is now computed in 64-bit, and an explicit assertion
        rejects sizes that do not fit in ``int`` because the launched
        kernels index these buffers with ``int``.
      * The volume assertion checks ``input_dims`` but its message said
        "kernel volume"; the message now names the checked quantity.

    Generated function parameters (in order):
        indices, hashdata_k, hashdata_v: ``tv::Tensor``.
        indice_pairs_bwd, indice_pairs_uniq: ``tv::Tensor``;
            ``indice_pairs_bwd`` is ``[kv, num_act_in]`` or empty,
            ``indice_pairs_uniq`` has ``kv * num_act_in + 1`` elements.
        indice_num_per_loc: ``tv::Tensor`` of shape ``[kv]``.
        batch_size: ``int``.
        output_dims, input_dims, ksize, stride, padding, dilation:
            ``tv::array<int, ndim>``.
        transposed: ``bool``, default ``false``.
        stream_int: ``std::uintptr_t`` CUDA stream handle, default ``0``.

    Returns:
        The populated ``pccm.FunctionCode`` (the generated function returns
        nothing).
    """
    code = pccm.FunctionCode()
    code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
    code.arg("indice_pairs_bwd, indice_pairs_uniq",
             "tv::Tensor")
    code.arg("indice_num_per_loc", "tv::Tensor")
    code.arg("batch_size", "int")
    code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>")
    code.arg("ksize, stride, padding, dilation",
             f"tv::array<int, {self.ndim}>")
    code.arg("transposed", f"bool", "false")
    code.arg("stream_int", f"std::uintptr_t", "0")
    code.raw(f"""
    // TODO stream
    // TODO handle num input == 0
    int kv = ksize.op<tv::arrayops::prod>();
    int num_act_in = indices.dim(0);
    // indice_pairs_bwd: [kv, num_act_in] or empty
    // indice_pairs_uniq: [kv * num_act_in + 1]
    if (!indice_pairs_bwd.empty()){{
        tv::check_shape(indice_pairs_bwd, {{kv, num_act_in}});
    }}
    tv::check_shape(indice_num_per_loc, {{kv}});
    // widen before multiplying: kv * num_act_in can exceed the int range.
    int64_t uniq_size = int64_t(kv) * int64_t(num_act_in) + 1;
    // the kernels launched below index these buffers with int.
    TV_ASSERT_RT_ERR(uniq_size <= int64_t(std::numeric_limits<int>::max()),
        "kv * num_act_in + 1 must fit in int32");
    TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) == uniq_size,
        "indice_pairs_uniq must have kv * num_act_in + 1 elements");
    tv::cuda::Launch launcher_num_act_in(indices.dim(0), reinterpret_cast<cudaStream_t>(stream_int));
    // tv::cuda::Launch launcher_num_act_in_2(indices.dim(0));
    launcher_num_act_in.blocks.y = kv;
    ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
    ConvLocIter loc_iter(problem);
    tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
    tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
        using V = {self.dtype_indices};
        using K = TV_DECLTYPE(I);
        using table_t =
            tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
                                            tv::hash::default_empty_key_v<K>, false>;
        table_t table = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
        tv::hash::clear_map_split(table, reinterpret_cast<cudaStream_t>(stream_int));

        using T = TV_DECLTYPE(I);
        TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
            "input volume must be smaller than max value of T");
        launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
        launcher_num_act_in(calc_conv_indices_stage1_mask_direct_table<T, table_t>, table,
            loc_iter, indices.data_ptr<const int>(),
            indice_pairs_bwd.data_ptr<{self.dtype_indices}>(),
            indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(),
            indices.dim(0),
            kv, transposed);
    }});
    """)
    return code
def generate_conv_inds_stage2_mask_template(self, is_direct_table: bool):
"""here indice_pairs_uniq may be bounded, some
points may be dropped.
"""
......@@ -1013,8 +1213,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
ctx.set_cuda_stream(custream);
int num_act_in = indices.dim(0);
int num_act_out = num_out_act;
""")
if not is_direct_table:
code.raw(f"""
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
""")
code.raw(f"""
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
// out_inds: [num_out_act, {self.ndim + 1}]
// auto timer = tv::CudaContextTimer<>();
......@@ -1030,11 +1234,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
// TODO handle invalid num_out_act
""")
if not is_direct_table:
code.raw(f"""
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
""")
with code.block("", start="tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){",
end="});"):
code.raw(f"""
using V = {self.dtype_indices};
using K = TV_DECLTYPE(I);
using table_t =
......@@ -1042,13 +1252,18 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::hash::default_empty_key_v<K>, false>;
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_out_act, "hash size not enough");
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
""")
if not is_direct_table:
# direct table built in stage 1.
code.raw(f"""
tv::hash::clear_map_split(hash, custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act);
""")
code.raw(f"""
if (!mask_bwd.empty()){{
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash,
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t, {pccm.literal(is_direct_table)}>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
......@@ -1064,7 +1279,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
mask_bwd[1].copy_(mask_bwd[0], ctx);
}}
}}else{{
launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t>, hash,
launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t, {pccm.literal(is_direct_table)}>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(),
......@@ -1073,11 +1288,130 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
mask_fwd[1].copy_(mask_fwd[0], ctx);
}}
}}
}});
""")
code.raw(f"""
return num_out_act;
""")
return code.ret("int")
@pccm.cuda.static_function
def generate_conv_inds_stage2_mask(self):
    """Emit the stage-2 mask generator for the non-direct-table path.

    Delegates to ``generate_conv_inds_stage2_mask_template`` with
    ``is_direct_table`` set to ``False``. As noted on the template,
    ``indice_pairs_uniq`` may be bounded, so some points may be dropped.
    """
    use_direct_table = False
    return self.generate_conv_inds_stage2_mask_template(use_direct_table)
@pccm.cuda.static_function
def generate_conv_inds_stage2_mask_direct_table(self):
    """Emit the stage-2 mask generator for the direct-table path.

    Delegates to ``generate_conv_inds_stage2_mask_template`` with
    ``is_direct_table`` set to ``True``. As noted on the template,
    ``indice_pairs_uniq`` may be bounded, so some points may be dropped.
    """
    use_direct_table = True
    return self.generate_conv_inds_stage2_mask_template(use_direct_table)
@pccm.cuda.static_function
def unique_and_assign_output_direct_hash(self):
    """Emit the host function that numbers hash entries and writes ``out_inds``.

    The generated function launches ``arange_hash_table_and_assign_out``
    with ``hashdata_k.size()`` work items over the table backed by
    ``hashdata_k`` / ``hashdata_v``. A non-positive ``num_out_bound`` is
    replaced by ``hashdata_k.size()``. Afterwards ``uniq_cnt`` is copied to
    the host and ``min(uniq_cnt[0], num_out_bound)`` is returned.

    ``uniq_cnt`` is only incremented here, never reset, so the caller is
    expected to pass it zeroed (NOTE(review): confirm at the call sites).

    Generated function parameters (in order):
        hashdata_k, hashdata_v, uniq_cnt: ``tv::Tensor``.
        out_inds: ``tv::Tensor`` read through ``data_ptr<int>()``.
        num_out_bound: ``int`` cap on kept outputs (``<= 0`` means table size).
        batch_size: ``int``.
        output_dims, input_dims, ksize, stride, padding, dilation:
            ``tv::array<int, ndim>``.
        stream_int: ``std::uintptr_t`` CUDA stream handle, default ``0``.

    Returns:
        The populated ``pccm.FunctionCode`` with C++ return type ``int``.
    """
    host_fn = pccm.FunctionCode()
    dims_type = f"tv::array<int, {self.ndim}>"
    host_fn.arg("hashdata_k, hashdata_v, uniq_cnt", "tv::Tensor")
    host_fn.arg("out_inds", "tv::Tensor")
    host_fn.arg("num_out_bound", "int")
    host_fn.arg("batch_size", "int")
    host_fn.arg("output_dims, input_dims", dims_type)
    host_fn.arg("ksize, stride, padding, dilation", dims_type)
    host_fn.arg("stream_int", "std::uintptr_t", "0")
    host_fn.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
tv::cuda::Launch lanucher_build_hash(hashdata_k.size(), custream);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
if (num_out_bound <= 0){{
num_out_bound = hashdata_k.size();
}}
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using V = {self.dtype_indices};
using K = TV_DECLTYPE(I);
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::default_empty_key_v<K>, false>;
table_t table = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
lanucher_build_hash(arange_hash_table_and_assign_out<table_t>, table,
out_inds.data_ptr<int>(), uniq_cnt.data_ptr<int>(), num_out_bound,
loc_iter.layout_npq);
}});
auto uniq_cnt_cpu = uniq_cnt.cpu(tvctx);
return std::min(uniq_cnt_cpu.data_ptr<int>()[0], num_out_bound);
""")
    return host_fn.ret("int")
@pccm.cuda.static_function
def unique_hash(self):
    """Emit the host function that numbers hash entries and collects their keys.

    The generated function launches ``arange_hash_table`` with
    ``hashdata_k.size()`` work items over the table backed by
    ``hashdata_k`` / ``hashdata_v``, passing ``out_indices_offset`` as the
    key output buffer. A non-positive ``num_out_bound`` is replaced by
    ``out_indices_offset.dim(0)``. Afterwards ``uniq_cnt`` is copied to the
    host and ``min(uniq_cnt[0], num_out_bound)`` is returned.

    ``uniq_cnt`` is only incremented here, never reset, so the caller is
    expected to pass it zeroed (NOTE(review): confirm at the call sites).

    Generated function parameters (in order):
        hashdata_k, hashdata_v, uniq_cnt, out_indices_offset: ``tv::Tensor``.
        num_out_bound: ``int`` cap on kept outputs (``<= 0`` means the
            length of ``out_indices_offset``).
        stream_int: ``std::uintptr_t`` CUDA stream handle, default ``0``.

    Returns:
        The populated ``pccm.FunctionCode`` with C++ return type ``int``.
    """
    host_fn = pccm.FunctionCode()
    tensor_params = "hashdata_k, hashdata_v, uniq_cnt, out_indices_offset"
    host_fn.arg(tensor_params, "tv::Tensor")
    host_fn.arg("num_out_bound", "int")
    host_fn.arg("stream_int", "std::uintptr_t", "0")
    host_fn.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
tv::cuda::Launch lanucher_build_hash(hashdata_k.size(), custream);
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
if (num_out_bound <= 0){{
num_out_bound = out_indices_offset.dim(0);
}}
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using V = {self.dtype_indices};
using K = TV_DECLTYPE(I);
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::default_empty_key_v<K>, false>;
table_t table = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
lanucher_build_hash(arange_hash_table<table_t>, table,
out_indices_offset.data_ptr<K>(),
uniq_cnt.data_ptr<int>(), num_out_bound);
}});
auto uniq_cnt_cpu = uniq_cnt.cpu(tvctx);
return std::min(uniq_cnt_cpu.data_ptr<int>()[0], num_out_bound);
""")
    return host_fn.ret("int")
@pccm.cuda.static_function
def assign_output_direct_hash(self):
    """Emit the host function that fills ``out_inds`` from stored offsets.

    The generated function asserts that ``out_indices_offset`` has at least
    ``out_inds.dim(0)`` entries, then launches ``assign_out_indices`` with
    ``out_inds.dim(0)`` work items, passing ``loc_iter.layout_npq`` built
    from the convolution problem description. The offset dtype is
    dispatched over ``int32_t`` / ``int64_t``.

    Generated function parameters (in order):
        out_indices_offset: ``tv::Tensor`` of offsets to invert.
        out_inds: ``tv::Tensor`` written through ``data_ptr<int>()``.
        batch_size: ``int``.
        output_dims, input_dims, ksize, stride, padding, dilation:
            ``tv::array<int, ndim>``.
        stream_int: ``std::uintptr_t`` CUDA stream handle, default ``0``.

    Returns:
        The populated ``pccm.FunctionCode`` (the generated function returns
        nothing).
    """
    host_fn = pccm.FunctionCode()
    dims_type = f"tv::array<int, {self.ndim}>"
    host_fn.arg("out_indices_offset", "tv::Tensor")
    host_fn.arg("out_inds", "tv::Tensor")
    host_fn.arg("batch_size", "int")
    host_fn.arg("output_dims, input_dims", dims_type)
    host_fn.arg("ksize, stride, padding, dilation", dims_type)
    host_fn.arg("stream_int", "std::uintptr_t", "0")
    host_fn.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
tv::cuda::Launch lanucher_build_hash(out_inds.dim(0), custream);
TV_ASSERT_RT_ERR(out_indices_offset.dim(0) >= out_inds.dim(0), "error");
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
tv::dispatch<int32_t, int64_t>(out_indices_offset.dtype(), [&](auto I){{
using K = TV_DECLTYPE(I);
lanucher_build_hash(assign_out_indices<K>, out_inds.data_ptr<int>(),
out_indices_offset.data_ptr<const K>(),
loc_iter.layout_npq, out_inds.dim(0));
}});
""")
    return host_fn
@pccm.cuda.static_function
def generate_subm_conv_inds(self):
code = pccm.FunctionCode()
......@@ -1175,6 +1509,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
class SparseConvIndicesCPU(pccm.ParameterizedClass):
def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType):
super().__init__()
self.add_dependency(TensorView)
......
......@@ -33,13 +33,21 @@ _TORCH_DTYPE_TO_TV = {
torch.int16: tv.int16,
torch.uint8: tv.uint8,
}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TORCH_UINT_WORKAROUNDS = {
tv.uint32: tv.int32,
tv.uint16: tv.int16,
tv.uint64: tv.int64
}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TV_DTYPE_TO_TORCH.update({
tv.uint32: torch.int32,
tv.uint16: torch.int16,
tv.uint64: torch.int64
})
_ALL_INTS = {
tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
tv.uint16
......@@ -106,91 +114,66 @@ class TorchAllocator(ExternalAllocator):
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
# TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit.
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.zeros(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def empty(self, name: str, shape: List[int], dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.empty(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def full_int(self, name: str, shape: List[int], value: int, dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def full_float(self, name: str, shape: List[int], value: float, dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def get_tensor_by_name(self, name: str):
......
......@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.utils import nullcontext
......@@ -46,7 +46,7 @@ from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer
DEBUG = False
DEBUG_INT64_HASH_K = True
DEBUG_INT64_HASH_K = False
INT32_MAX = SpconvOps.get_int32_max()
......@@ -77,12 +77,17 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
class _HashData:
def __init__(self, num: int, use_i64: bool, device: torch.device) -> None:
def __init__(self,
num: int,
use_i64: bool,
device: torch.device,
rate: float = 2.0) -> None:
if use_i64:
self.hashdata_k = torch.empty((num * 2, ),
self.hashdata_k = torch.empty((int(num * rate), ),
dtype=torch.int64,
device=device)
self.hashdata_v = torch.empty((num * 2, ),
self.hashdata_v = torch.empty((int(num * rate), ),
dtype=torch.int32,
device=device)
self.hashdata_k_tv = torch_tensor_to_tv(self.hashdata_k)
......@@ -91,7 +96,7 @@ class _HashData:
else:
self.hashdata = torch.empty((
2,
num * 2,
int(num * rate),
),
dtype=torch.int32,
device=device)
......@@ -309,7 +314,8 @@ def get_indice_pairs_implicit_gemm(
is_train: bool = True,
alloc: Optional[ThrustSortAllocator] = None,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
num_out_act_bound: int = -1):
num_out_act_bound: int = -1,
direct_table: bool = True):
"""
Why return tuple? because pytorch seems don't support custom object in autograd.
return: (
......@@ -323,14 +329,33 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_bwd_splits, # torch.Tensor() if subm or inference mode
masks,
)
direct_table: a hash-based regular conv pair gen algo to avoid unique operation.
runs faster than pytorch unique with num_voxel < 1000k.
"""
stream = get_current_stream()
if SPCONV_CPP_INDICE_PAIRS_IGEMM:
thalloc = TorchAllocator(indices.device)
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc, torch_tensor_to_tv(indices), batch_size, spatial_shape,
algo.value, ksize, stride, padding, dilation, out_padding, subm,
transpose, is_train, stream, num_out_act_bound)
thalloc,
torch_tensor_to_tv(indices),
batch_size,
spatial_shape,
algo.value,
ksize,
stride,
padding,
dilation,
out_padding,
subm,
transpose,
is_train,
stream,
num_out_act_bound,
timer=timer_cpp,
direct_table=direct_table)
mask_split_count = mask_tensor.dim(0)
masks = [mask_tensor[i:i + 1].numpy() for i in range(mask_split_count)]
if subm:
......@@ -342,7 +367,6 @@ def get_indice_pairs_implicit_gemm(
# for subm, if training, pair shape is [2, kv, ...]
# if not training, pair is [1, kv, ...]
pair = thalloc.allocated[AllocKeys.PairFwd]
pair_mask = thalloc.allocated[AllocKeys.PairMask]
mask_argsort = thalloc.allocated[AllocKeys.MaskArgSort]
pair_mask_in_splits = [
......@@ -367,7 +391,6 @@ def get_indice_pairs_implicit_gemm(
if is_train:
pair_mask_bwd = thalloc.allocated[AllocKeys.PairMaskBwd]
mask_argsort_bwd = thalloc.allocated[AllocKeys.MaskArgSortBwd]
mask_argsort_fwd = thalloc.allocated[AllocKeys.MaskArgSort]
if not is_train:
pair_mask_bwd_splits: List[torch.Tensor] = []
......@@ -388,11 +411,6 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
t = 0
if DEBUG:
CONV.stream_synchronize(stream)
t = time.time()
assert indices.is_cuda, "implicit gemm only support cuda"
ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
......@@ -452,8 +470,6 @@ def get_indice_pairs_implicit_gemm(
masks = [first.astype(np.uint32), second.astype(np.uint32)]
else:
masks = [np.array([0xffffffff], dtype=np.uint32)]
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
if subm:
out_inds = indices
......@@ -508,10 +524,6 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_in_splits = [
mask_argsort[i] for i in range(mask_split_count)
]
if DEBUG:
CONV.stream_synchronize(stream)
print("SUBM", time.time() - t)
if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
......@@ -519,11 +531,10 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_PREPARE", time.time() - t)
t = time.time()
max_num_act = SpconvOps.get_handcrafted_max_act_out(
indices.shape[0], ksize, stride, padding, dilation)
if transpose:
max_num_act = kv * indices.shape[0]
pair_bwd = pair
pair_bwd_tv = pair_tv
......@@ -531,8 +542,38 @@ def get_indice_pairs_implicit_gemm(
dtype=indice_dtype,
device=indices.device)
indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq)
hashdata = _HashData(0, use_int64_hash_k, indices.device)
indice_pairs_uniq_bkp_tv = tv.Tensor()
if direct_table:
# print("HASH SIZE", max_num_act * 2)
hashdata = _HashData(max_num_act, use_int64_hash_k, indices.device,
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE)
indice_pairs_uniq_bkp = torch.empty((pair.numel() + 1, ),
dtype=indice_dtype,
device=indices.device)
indice_pairs_uniq_bkp_tv = torch_tensor_to_tv(
indice_pairs_uniq_bkp)
with timer.record("gen_conv_inds_stage1", stream):
SpconvOps.generate_conv_inds_mask_stage1(inds_tv,
SpconvOps.generate_conv_inds_mask_stage1_direct_table(
inds_tv,
hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
pair_bwd_tv,
indice_pairs_uniq_bkp_tv,
indice_num_per_loc_tv,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream)
else:
with timer.record("gen_conv_inds_stage1", stream):
SpconvOps.generate_conv_inds_mask_stage1(
inds_tv,
pair_bwd_tv,
indice_pairs_uniq_tv,
indice_num_per_loc_tv,
......@@ -545,23 +586,31 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation,
transposed=transpose,
stream_int=stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S1", time.time() - t)
t = time.time()
uniq_out_indices_offset_tv = tv.Tensor()
with timer.record(f"unique_{indice_pairs_uniq.shape[0]}", stream):
if direct_table:
uniq_cnt = torch.zeros([1],
dtype=torch.int32,
device=indices.device)
uniq_cnt_tv = torch_tensor_to_tv(uniq_cnt)
num_act_out = SpconvOps.unique_hash(hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
uniq_cnt_tv,
indice_pairs_uniq_tv,
num_out_act_bound, stream)
uniq_out_indices_offset_tv = indice_pairs_uniq_tv
raw_out_indices_offset_tv = indice_pairs_uniq_bkp_tv
else:
uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1
uniq_out_indices_offset_tv = torch_tensor_to_tv(uniq_res)
raw_out_indices_offset_tv = indice_pairs_uniq_tv
if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_UNIQ", time.time() - t)
t = time.time()
with timer.record(f"alloc_stage2", stream):
uniq_res_tv = torch_tensor_to_tv(uniq_res)
out_inds = torch.empty((num_act_out, indices.shape[1]),
dtype=indices.dtype,
device=indices.device)
......@@ -574,15 +623,18 @@ def get_indice_pairs_implicit_gemm(
dtype=torch.int32,
device=indices.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_tv = torch_tensor_to_tv(pair_mask_fwd, dtype=tv.uint32)
pair_mask_fwd_tv = torch_tensor_to_tv(pair_mask_fwd,
dtype=tv.uint32)
pair_mask_bwd = torch.Tensor()
pair_mask_bwd_tv = tv.Tensor()
if is_train:
pair_mask_bwd = torch.zeros((mask_split_count, indices.shape[0]),
pair_mask_bwd = torch.zeros(
(mask_split_count, indices.shape[0]),
dtype=torch.int32,
device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
dtype=tv.uint32)
if not direct_table:
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device)
......@@ -591,19 +643,28 @@ def get_indice_pairs_implicit_gemm(
# device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
if DEBUG:
with timer.record(f"gen_conv_inds_stage2_{num_act_out}", stream):
stage2_fn = SpconvOps.generate_conv_inds_mask_stage2
if direct_table:
SpconvOps.assign_output_direct_hash(indice_pairs_uniq_tv,
out_inds_tv,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
stream_int=stream)
stage2_fn = SpconvOps.generate_conv_inds_stage2_mask_direct_table
CONV.stream_synchronize(stream)
print("REGU_S2_PREPARE", time.time() - t)
t = time.time()
with timer.record("gen_conv_inds_stage2", stream):
SpconvOps.generate_conv_inds_mask_stage2(inds_tv,
stage2_fn(inds_tv,
hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
pair_fwd_tv,
pair_bwd_tv,
uniq_res_tv,
indice_pairs_uniq_tv,
uniq_out_indices_offset_tv,
raw_out_indices_offset_tv,
out_inds_tv,
pair_mask_fwd_tv,
pair_mask_bwd_tv,
......@@ -617,12 +678,6 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation,
transposed=transpose,
stream_int=stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S2", time.time() - t)
t = time.time()
mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32,
device=indices.device)
......@@ -693,10 +748,6 @@ def get_indice_pairs_implicit_gemm(
SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S2_FINISH", time.time() - t)
t = time.time()
# CONV.stream_synchronize(stream)
if not is_train:
......@@ -716,9 +767,6 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_fwd_splits = [
mask_argsort_fwd[i] for i in range(mask_split_count)
]
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU", time.time() - t)
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
......@@ -769,8 +817,7 @@ def indice_conv(features: torch.Tensor,
stream = get_current_stream()
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC,
FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv,
arch,
indice_pairs_tv, indice_pair_num_tv, arch,
num_activate_out, inverse, subm, algo.value,
stream)
out_features = alloc.allocated[AllocKeys.OutFeatures]
......@@ -1018,8 +1065,8 @@ def indice_conv_backward(features: torch.Tensor,
ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv,
indice_pairs_tv, indice_pair_num_tv,
arch,
inverse, subm, algo.value, stream)
arch, inverse, subm, algo.value,
stream)
din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters]
return din, df
......@@ -1369,8 +1416,8 @@ def implicit_gemm(features: torch.Tensor,
mask_width = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, timer_cpp,
auto_fp32_accum, fp32_accum)
num_activate_out, mask_tv, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum)
out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train:
......@@ -1460,7 +1507,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream)
# t = time.time()
print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"):
with timer.record("implicit_gemm", stream):
for j in range(num_split):
......@@ -1921,8 +1968,10 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
indice_pairs_tv, stream)
return din
def indice_avgpool_implicit_gemm(features: torch.Tensor,
indice_pairs: torch.Tensor, num_activate_out, calc_count: bool):
indice_pairs: torch.Tensor, num_activate_out,
calc_count: bool):
# torch.cuda.synchronize()
# t = time.time()
stream = get_current_stream()
......@@ -1943,12 +1992,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
count_out = torch.Tensor()
count_out_tv = tv.Tensor()
if calc_count:
count_out = torch.zeros((num_activate_out,),
count_out = torch.zeros((num_activate_out, ),
dtype=torch.int32,
device=features.device)
count_out_tv = torch_tensor_to_tv(count_out)
SpconvOps.avgpool_implicit_gemm_forward(out_features_tv, features_tv,
indice_pairs_tv, count_out_tv, stream)
indice_pairs_tv, count_out_tv,
stream)
# CONV.stream_synchronize(stream)
# print("M", time.time() - t)
......@@ -1956,12 +2006,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
return out_features, count_out
def indice_avgpool_implicit_gemm_backward(out_bp,
indice_pairs, count_out):
def indice_avgpool_implicit_gemm_backward(out_bp, indice_pairs, count_out):
# torch.cuda.synchronize()
# t = time.time()
out_channel = out_bp.shape[-1]
din = torch.zeros((indice_pairs.shape[1], out_bp.shape[1]), dtype=out_bp.dtype, device=out_bp.device)
din = torch.zeros((indice_pairs.shape[1], out_bp.shape[1]),
dtype=out_bp.dtype,
device=out_bp.device)
assert out_bp.is_cuda
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
......@@ -1972,7 +2023,8 @@ def indice_avgpool_implicit_gemm_backward(out_bp,
din_tv = torch_tensor_to_tv(din)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
SpconvOps.avgpool_implicit_gemm_backward(out_bp_tv, din_tv,
indice_pairs_tv, count_out_tv, stream)
indice_pairs_tv, count_out_tv,
stream)
return din
......
......@@ -323,6 +323,8 @@ def main():
# pickle.dump((voxels, coors, spatial_shape), f)
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f)
# voxels, coors, spatial_shape = waymo_data_large()
print(spatial_shape)
print(voxels.shape)
# voxels = voxels[:100]
......@@ -366,15 +368,14 @@ def main():
dout = np.random.uniform(-0.2, 0.2, out.features.shape).astype(np.float32)
dout_t = torch.from_numpy(dout).to(device).to(dtype)
print(out.spatial_shape, out.features.mean(), out.features.max(),
print(out.spatial_shape, out.features.sum(1).mean(), out.features.max(),
out.features.min())
times = []
show_metrics = False
with torch.no_grad():
for i in range(20):
print("------------")
torch.cuda.synchronize()
t = time.time()
for i in range(100):
# print("------------")
with tv.measure_duration() as measure:
out_nograd = net(voxels_th, coors_th, 1, show_metrics)
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
......@@ -383,14 +384,19 @@ def main():
# print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values()))
torch.cuda.synchronize()
# sort_bench()
times.append(time.time() - t)
times.append(measure.duration)
if show_metrics:
timer = out_nograd._timer
items = list(timer.get_all_pair_time().items())
items.sort(key=lambda x: x[0])
print("SUM TIME:", sum([x[1] for x in items]))
print(json.dumps(dict(items), indent=2))
inds_sum = 0
for k, v in items:
if "gen_pairs" in k:
inds_sum += v
print("SUM GEN INDS:", inds_sum)
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
......
......@@ -231,8 +231,8 @@ def _test_impgemm_conv_cuda(subm: bool):
# out_channels = [32, 48, 64]
in_channels = [32, 47]
out_channels = [32, 48, 62]
in_channels = [32]
out_channels = [32]
# in_channels = [32]
# out_channels = [32]
multiple_base = 16
if subm:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment