Unverified Commit ee8c9465 authored by FindDefinition's avatar FindDefinition Committed by GitHub
Browse files

Large kernel for implicit gemm (#547)



* large kernel bwd&bwdI, not test increment RS

* large kernel fix, no split_mask and increment rs

* large kernel fix2, no split_mask and increment rs

* reset benchmark.py

* fix merge
Co-authored-by: default avatarEvernightAurora <2465542858@qq.com>
parent bdfbf4a2
......@@ -618,6 +618,7 @@ class SimpleConv:
]
self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
self.lock = Lock()
self.static_key_to_desps = group_by(self.get_static_key, all_desps)
......@@ -823,6 +824,7 @@ class SimpleConv:
mask_argsort: tv.Tensor,
indices: tv.Tensor,
reverse_mask: bool,
mask_int_count: int = 1,
mask_filter: int = 0xffffffff,
mask_width: int = -1,
mask_output: tv.Tensor = tv.Tensor(),
......@@ -863,6 +865,8 @@ class SimpleConv:
params.indices = indices
params.mask = mask
params.mask_output = mask_output
params.mask_int_count = mask_int_count
# if op_type == ConvOpType.kBackwardWeight:
# assert not mask_output.empty()
if op_type == ConvOpType.kBackwardInput:
......@@ -940,7 +944,8 @@ class SimpleConv:
bias: Optional[tv.Tensor] = None,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
mask_int_count: Union[int, None] = None):
channel_k = output.dim(1)
channel_c = inp.dim(1)
# GemmMainUnitTest.stream_synchronize(stream)
......@@ -981,6 +986,7 @@ class SimpleConv:
params.mask_filter = mask_filter
params.mask_output = mask_output
params.reverse_mask = reverse_mask
params.mask_int_count = mask_int_count
if bias is not None:
params.bias = bias
if timer.enable:
......
......@@ -144,7 +144,7 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
"""
Args:
indices:
......@@ -167,10 +167,11 @@ class SpconvOps:
dilation:
transposed:
stream_int:
mask_int_count:
"""
...
@staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
"""
Args:
indices:
......@@ -193,10 +194,11 @@ class SpconvOps:
dilation:
transposed:
stream_int:
mask_int_count:
"""
...
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
"""
Args:
indices:
......@@ -212,6 +214,7 @@ class SpconvOps:
indice_pair_mask:
backward:
stream_int:
mask_int_count:
"""
...
@staticmethod
......@@ -380,7 +383,7 @@ class SpconvOps:
"""
...
@staticmethod
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
def sort_1d_by_key_allocator_mask32(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
......@@ -390,6 +393,58 @@ class SpconvOps:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask32_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
alloc_func:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto_v2(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
"""
...
@staticmethod
def sort_1d_by_key_split(data: Tensor, mask: Tensor, indices: Tensor = Tensor(), stream: int = 0, mask_output: bool = False) -> Tensor:
"""
Args:
......@@ -543,7 +598,7 @@ class SpconvOps:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int, int]:
"""
Args:
allocator:
......
......@@ -48,7 +48,7 @@ class ConvTunerSimple:
stream_int:
"""
...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]:
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]:
"""
Args:
op_type:
......@@ -72,6 +72,7 @@ class ConvTunerSimple:
alpha:
beta:
stream_int:
mask_int_count:
auto_fp32_accum:
fp32_accum:
num_run:
......@@ -91,7 +92,7 @@ class ConvTunerSimple:
mask_width:
"""
...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None:
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None:
"""
Args:
profile_res:
......@@ -109,6 +110,7 @@ class ConvTunerSimple:
alpha:
beta:
stream_int:
mask_int_count:
workspace:
verbose:
timer:
......
......@@ -63,7 +63,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]:
"""
Args:
allocator:
......@@ -75,6 +75,7 @@ class ConvGemmOps:
mask_argsort_fwd_splits:
num_activate_out:
masks:
mask_int_count:
arch:
is_train:
is_subm:
......@@ -90,7 +91,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
"""
Args:
allocator:
......@@ -106,6 +107,7 @@ class ConvGemmOps:
mask_argsort_bwd_splits:
mask_output_fwd:
masks:
mask_int_count:
arch:
mask_width:
is_subm:
......
......@@ -462,6 +462,7 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
......@@ -488,7 +489,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count);
}}
""")
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
......@@ -512,6 +513,7 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
......@@ -538,7 +540,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count);
}}
""")
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
......@@ -559,6 +561,7 @@ class SpconvOps(pccm.Class):
"cumm.tensorview.Tensor = Tensor()")
code.arg("backward", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
......@@ -579,7 +582,7 @@ class SpconvOps(pccm.Class):
indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward,
stream_int);
stream_int, mask_int_count);
}}
""")
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
......@@ -906,7 +909,7 @@ class SpconvOps(pccm.Class):
""")
return code
def sort_1d_by_key_allocator_template(self, use_allocator: bool):
def sort_1d_by_key_allocator_template(self, use_allocator: bool, int_count: int = 1):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
......@@ -942,18 +945,19 @@ class SpconvOps(pccm.Class):
code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
indices = tv::empty({{data.dim(0) / {int_count}}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
// auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
using T_ = TV_DECLTYPE(I);
using T = {"T_" if int_count == 1 else f"thrust::tuple<{', '.join(['T_'] * int_count) }>"};
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0) / {int_count}, ptr_k);
}});
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
......@@ -963,16 +967,73 @@ class SpconvOps(pccm.Class):
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator(self):
def sort_1d_by_key_allocator_mask32(self):
# for python
return self.sort_1d_by_key_allocator_template(False)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask32_v2(self):
# for python
return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128(self):
# for python
return self.sort_1d_by_key_allocator_template(False, 4)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128_v2(self):
# for python
return self.sort_1d_by_key_allocator_template(True, 4)
def sort_1d_by_key_allocator_mask_auto_template(self, use_allocator: bool):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("data", "tv::Tensor")
if not use_allocator:
code.arg("alloc_param", "std::function<std::uintptr_t(std::size_t)>")
else:
code.arg("alloc_param", "ThrustAllocator&")
code.arg("indices",
"tv::Tensor",
"tv::Tensor()",
pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
switch (mask_int_count){{
case 1:
return sort_1d_by_key_allocator_mask32{"_v2" if use_allocator else ""}(data, alloc_param, indices, stream);
case 4:
return sort_1d_by_key_allocator_mask128{"_v2" if use_allocator else ""}(data, alloc_param, indices, stream);
default:
TV_ASSERT_RT_ERR(false, "Not implement for other mask_int_count");
return tv::Tensor();
}}
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto(self):
return self.sort_1d_by_key_allocator_mask_auto_template(False)
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto_v2(self):
return self.sort_1d_by_key_allocator_mask_auto_template(True)
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_v2(self):
# for cpp only
return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_split(self):
......@@ -1659,7 +1720,11 @@ class SpconvOps(pccm.Class):
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
int mask_int_count = (kv + 31) / 32;
if (mask_int_count > 1 && mask_int_count < 4)
mask_int_count = 4;
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
// TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int> out_shape;
if (!subm){{
if (transposed){{
......@@ -1728,6 +1793,7 @@ class SpconvOps(pccm.Class):
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
if (is_mask_split){{
TV_ASSERT_RT_ERR(mask_int_count == 1, "not support for kv > 32");
auto kv_div_2 = kv / 2;
auto remain = kv - kv_div_2;
uint64_t mask_np_1 = 1;
......@@ -1779,14 +1845,14 @@ class SpconvOps(pccm.Class):
pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)});
}}else{{
pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_in}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_in * mask_int_count}}, tv::uint32, 0, stream_int);
}}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int, mask_int_count);
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
sort_1d_by_key_allocator_mask_auto_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
}}
""")
with code.else_():
......@@ -1892,11 +1958,11 @@ Your Conv Params: )" << "\\n";
pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_out}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_out * mask_int_count}}, tv::uint32, 0, stream_int);
pair_mask_bwd = tv::Tensor();
if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0, stream_int);
{{mask_split_count, indices.dim(0) * mask_int_count}}, tv::uint32, 0, stream_int);
}}
}}
if (!direct_table){{
......@@ -1928,13 +1994,13 @@ Your Conv Params: )" << "\\n";
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
transposed, stream_int, mask_int_count);
}}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
transposed, stream_int, mask_int_count);
}}
}}
""")
......@@ -1964,21 +2030,21 @@ Your Conv Params: )" << "\\n";
}}
}}else{{
if (!is_train){{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int);
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
}}else{{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int);
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int);
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count);
}}
}}
}}
""")
code.raw(f"""
return std::make_tuple(mask_tensor, num_act_out);
return std::make_tuple(mask_tensor, num_act_out, mask_int_count);
""")
return code.ret("std::tuple<tv::Tensor, int>")
return code.ret("std::tuple<tv::Tensor, int, int>")
@pccm.pybind.mark
@pccm.static_function
......
......@@ -1138,6 +1138,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false")
code.arg("num_run", "int", "5")
......@@ -1186,6 +1187,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.indices = indices;
params.mask = mask;
params.mask_output = mask_output;
params.mask_int_count = mask_int_count;
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
// }}
......@@ -1335,7 +1338,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.arg("workspace", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("verbose", f"bool", "false")
......@@ -1390,6 +1393,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.mask_width = mask_width;
params.mask_output = mask_output;
params.reverse_mask = reverse_mask;
params.mask_int_count = mask_int_count;
if (timer.enable()){{
params.timer = timer;
......@@ -2035,6 +2039,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>")
code.arg("num_activate_out", "int")
code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>")
code.arg("is_train, is_subm", "bool", "false")
......@@ -2108,6 +2113,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count, // mask_int_count is after stream_int
auto_fp32_accum,
fp32_accum,
5, // num_run
......@@ -2120,7 +2126,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{
mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)},
{{num_split, tv::div_up(num_activate_out, mask_width)}},
{{num_split, tv::div_up(num_activate_out, mask_width) * mask_int_count}},
tv::uint32, features.device(), stream_int);
for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]);
......@@ -2154,6 +2160,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width
1.0, beta,
stream_int,
mask_int_count, // mask_int_count is after stream_int
tv::Tensor(), // workspace
false, // verbose
timer,
......@@ -2186,6 +2193,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("mask_output_fwd", "tv::Tensor")
code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int")
......@@ -2278,6 +2286,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count,
auto_fp32_accum,
fp32_accum,
5, // num_run
......@@ -2302,6 +2311,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count,
auto_fp32_accum,
fp32_accum,
5, // num_run
......@@ -2344,6 +2354,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width
1.0, beta,
stream_int,
mask_int_count,
tv::Tensor(), // workspace
false, // verbose
timer);
......@@ -2361,6 +2372,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_width,
1.0, 0.0,
stream_int,
mask_int_count,
workspace, // workspace
false, // verbose
timer);
......
......@@ -613,10 +613,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices_in", "int")
code.arg("num_indices_out", "int")
code.arg("mask_int_count", "int")
code.raw(f"""
int filter_offset = blockIdx.y;
uint32_t filter_mask_fwd = (1u << (filter_offset));
int filter_pointer_offset = filter_offset / 32;
uint32_t filter_mask_fwd = (1u << (filter_offset % 32));
// TODO following rule for even kernel size is wrong.
// uint32_t filter_mask_bwd = (1u << (gridDim.y - 1 - filter_offset));
......@@ -633,7 +635,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto output_index = table.value_ptr()[table_offset];
bool valid = CheckValueValid ? output_index >= 0 : true;
if (valid){{
atomicOr(mask_fwd + output_index, filter_mask_fwd);
atomicOr(mask_fwd + output_index * mask_int_count + filter_pointer_offset, filter_mask_fwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd);
indice_pairs_fwd_filter[output_index] = input_index;
if (indice_pairs_bwd != nullptr){{
......@@ -655,15 +657,19 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices_in", "int")
code.arg("kv", "int")
code.arg("mask_int_count", "int")
code.raw(f"""
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
for (int mask_offset = 0; mask_offset < mask_int_count; ++mask_offset){{
uint32_t mask = 0;
for (int filter_offset = 0; filter_offset < kv; ++filter_offset){{
for (int filter_offset = mask_offset * 32; filter_offset < mask_offset * 32 + 32 && filter_offset < kv; ++filter_offset){{
auto val = indice_pairs_bwd[filter_offset * num_indices_in + input_index];
mask |= (val != -1) << filter_offset;
mask |= (val != -1) << (filter_offset % 32);
}}
mask_bwd[input_index * mask_int_count + mask_offset] = mask;
}}
mask_bwd[input_index] = mask;
}}
""")
return code
......@@ -685,11 +691,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("mask_fwd", f"uint32_t*") # [kernelProd]
code.arg("num_indices_in", "int")
code.arg("num_indices_out", "int")
code.arg("mask_int_count", "int")
# TODO use block instead of filter_offset?
code.raw(f"""
int filter_offset = blockIdx.y;
uint32_t filter_mask_fwd = (1u << (filter_offset));
int filter_pointer_offset = filter_offset / 32;
uint32_t filter_mask_fwd = (1u << (filter_offset % 32));
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
// auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
......@@ -702,7 +710,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto output_index = table.value_ptr()[table_offset];
bool valid = CheckValueValid ? output_index >= 0 : true;
if (valid){{
atomicOr(mask_fwd + output_index, filter_mask_fwd);
atomicOr(mask_fwd + output_index * mask_int_count + filter_pointer_offset, filter_mask_fwd);
indice_pairs_fwd_filter[output_index] = input_index;
}}
}}
......@@ -812,11 +820,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("RS", "int")
code.arg("is_train", "bool")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int filter_offset = blockIdx.y;
uint32_t filter_mask_out = (1u << (filter_offset));
uint32_t filter_mask_in = (1u << (RS - 1 - filter_offset));
uint32_t filter_mask_out = (1u << (filter_offset % 32));
uint32_t filter_mask_out_offset = filter_offset / 32;
uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32));
uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32;
// uint32_t filter_mask_center = (1u << (RS / 2));
loc_iter.set_filter_offset(filter_offset);
......@@ -843,8 +854,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto table_offset = table.lookup_offset(offset); // performance bound
if (table_offset != -1){{
auto input_index = table.value_ptr()[table_offset]; // we find a input indice idx.
atomicOr(mask + output_index, filter_mask_out);
atomicOr(mask + input_index, filter_mask_in);
atomicOr(mask + output_index * mask_int_count + filter_mask_out_offset, filter_mask_out);
atomicOr(mask + input_index * mask_int_count + filter_mask_in_offset, filter_mask_in);
// for this output, we set correct input idx.
indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index;
if (is_train){{
......@@ -1244,6 +1255,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
......@@ -1320,11 +1332,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
num_act_in, indice_pairs_fwd.dim(1));
num_act_in, indice_pairs_fwd.dim(1),
mask_int_count);
launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output,
indice_pairs_bwd.data_ptr<int>(),
mask_bwd.data_ptr<uint32_t>(),
num_act_in, kv);
num_act_in, kv,
mask_int_count);
if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx);
}}
......@@ -1336,7 +1350,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(),
num_act_in, indice_pairs_fwd.dim(1));
num_act_in, indice_pairs_fwd.dim(1),
mask_int_count);
if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx);
}}
......@@ -1489,6 +1504,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"cumm.tensorview.Tensor = Tensor()")
code.arg("is_train", "bool", "true")
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int num_act_in_real = indices.dim(0);
......@@ -1496,7 +1512,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto ctx = tv::Context();
ctx.set_cuda_stream(custream);
if (!indice_pair_mask.empty()){{
TV_ASSERT_INVALID_ARG(ksize.op<tv::arrayops::prod>() <= 32, "for now only support 32bit mask");
// TV_ASSERT_INVALID_ARG(ksize.op<tv::arrayops::prod>() <= 32, "for now only support 32bit mask");
}}
// TODO stream
// TODO handle num input == 0
......@@ -1557,11 +1573,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}}else{{
// indice_pair_mask: [1, num_act_in]
tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
if (mask_int_count == 1)
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
else
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv, is_train);
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv, is_train, mask_int_count);
}}
}}else{{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
......@@ -1576,6 +1596,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
return code.ret("int")
@pccm.cuda.cuda_global_function
def init_subm_multiple_mask_int_kernel(self):
code = pccm.FunctionCode()
code.targ("T")
code.arg("ptr", "T*")
code.arg("set_bit", "int")
code.arg("length", "int")
code.arg("mask_int_count", "int")
code.raw(f"""
int initial_offset = blockIdx.x * blockDim.x + threadIdx.x;
int bit_offset = set_bit / 32;
int bit_residue = set_bit % 32;
for(int offset : tv::KernelLoopX<int>(length)){{
for (int i=0; i < mask_int_count; ++i)
ptr[offset * mask_int_count + i] = (i == bit_offset) * (1 << bit_residue);
}}
""")
return code
class SparseConvIndicesCPU(pccm.ParameterizedClass):
......
......@@ -111,8 +111,8 @@ class SparseConvolution(SparseModule):
algo = ConvAlgo.MaskImplicitGemm
else:
algo = ConvAlgo.Native
if kv > 32:
assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now"
# if kv > 32:
# assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now"
if CPU_ONLY_BUILD:
assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
self.algo = algo
......@@ -481,6 +481,7 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks
mask_int_count = datas.mask_int_count
assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape,
datas)
......@@ -523,6 +524,7 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7]
masks = res[8]
mask_int_count = res[9]
if self.indice_key is not None:
indice_data = ImplicitGemmIndiceData(
outids,
......@@ -541,7 +543,8 @@ class SparseConvolution(SparseModule):
ksize=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation)
dilation=self.dilation,
mask_int_count=mask_int_count)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data
......@@ -553,7 +556,7 @@ class SparseConvolution(SparseModule):
features, self.weight, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm,
num_activate_out, masks, mask_int_count, self.training, self.subm,
input._timer, self.fp32_accum,
bias_for_infer,
self.act_alpha,
......
......@@ -89,7 +89,8 @@ class ImplicitGemmIndiceData(object):
out_spatial_shape, is_subm: bool, algo: ConvAlgo,
ksize: List[int], stride: List[int], dilation: List[int], padding: List[int],
in_voxel_num: Optional[Any] = None,
out_voxel_num: Optional[Any] = None):
out_voxel_num: Optional[Any] = None,
mask_int_count: int=1):
self.out_indices = out_indices
self.indices = indices
self.pair_fwd = pair_fwd
......@@ -110,6 +111,7 @@ class ImplicitGemmIndiceData(object):
# in/out voxel_num is only used in tensorrt conversion.
self.in_voxel_num = in_voxel_num
self.out_voxel_num = out_voxel_num
self.mask_int_count = mask_int_count
def scatter_nd(indices, updates, shape):
......
......@@ -198,6 +198,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits: List[torch.Tensor],
num_activate_out: int,
masks: List[np.ndarray],
mask_int_count: int,
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
......@@ -209,7 +210,7 @@ class SparseImplicitGemmFunction(Function):
try:
out, mask_out, mask_width = ops.implicit_gemm(
features, filters, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits, num_activate_out, masks, is_train,
mask_argsort_fwd_splits, num_activate_out, masks, mask_int_count, is_train,
is_subm, timer, fp32_accum, bias, act_alpha, act_beta,
act_type)
except Exception as e:
......@@ -235,6 +236,7 @@ class SparseImplicitGemmFunction(Function):
ctx.masks = masks
ctx.is_subm = is_subm
ctx.fp32_accum = fp32_accum
ctx.mask_int_count = mask_int_count
return out
@staticmethod
......@@ -253,6 +255,7 @@ class SparseImplicitGemmFunction(Function):
is_subm = ctx.is_subm
timer = ctx.timer
fp32_accum = ctx.fp32_accum
mask_int_count = ctx.mask_int_count
try:
input_bp, filters_bp = ops.implicit_gemm_backward(
......@@ -267,6 +270,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits,
mask_output_fwd=mask_out,
masks=masks,
mask_int_count=mask_int_count,
mask_width=mask_width,
is_subm=is_subm,
timer=timer,
......@@ -282,7 +286,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits, masks))
raise e
None_9 = [None] * 16
None_9 = [None] * 17
return (input_bp, filters_bp, *None_9)
......
......@@ -365,7 +365,7 @@ def get_indice_pairs_implicit_gemm(
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
mask_tensor, num_act_out, mask_int_count = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc,
torch_tensor_to_tv(indices),
batch_size,
......@@ -408,7 +408,7 @@ def get_indice_pairs_implicit_gemm(
assert pair.shape[0] == 2
pair_bwd = pair[1]
return (out_inds, indice_num_per_loc, pair[0], pair_bwd,
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
else:
pair_bwd = thalloc.allocated.get(AllocKeys.PairBwd, torch.Tensor())
pair_fwd = thalloc.allocated[AllocKeys.PairFwd]
......@@ -437,12 +437,16 @@ def get_indice_pairs_implicit_gemm(
]
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks, mask_int_count)
assert indices.is_cuda, "implicit gemm only support cuda"
ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
# TODO in future we will support up to 128 kernel volume.
assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm"
# assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm"
mask_int_count = (kv + 31) // 32
if 1 < mask_int_count < 4:
mask_int_count = 4
assert mask_int_count in [1, 4]
if not subm:
if transpose:
......@@ -489,6 +493,7 @@ def get_indice_pairs_implicit_gemm(
pair_tv = torch_tensor_to_tv(pair)
indice_num_per_loc_tv = torch_tensor_to_tv(indice_num_per_loc)
if is_mask_split:
assert mask_int_count == 1, "Not Implemented"
kv_div_2 = kv // 2
remain = kv - kv_div_2
mask_np_1 = np.array([1], dtype=np.uint64)
......@@ -506,7 +511,7 @@ def get_indice_pairs_implicit_gemm(
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device)
pair_mask = torch.empty((mask_split_count, indices.shape[0]),
pair_mask = torch.empty((mask_split_count, indices.shape[0] * mask_int_count),
dtype=torch.int32,
device=indices.device)
......@@ -526,7 +531,8 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation,
indice_pair_mask=pair_mask_tv,
backward=is_train,
stream_int=stream)
stream_int=stream,
mask_int_count=mask_int_count)
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
# CONV.stream_synchronize(stream)
......@@ -543,9 +549,10 @@ def get_indice_pairs_implicit_gemm(
# so I use this stupid hack to use torch allocator without touch
# pytorch binary (c++).
# f**k thrust
SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j],
SpconvOps.sort_1d_by_key_allocator_mask_auto(pair_mask_tv[j],
alloc.alloc,
mask_argsort_tv[j], stream)
mask_argsort_tv[j], stream,
mask_int_count)
# CONV.stream_synchronize(stream)
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
mask_argsort_in_splits = [
......@@ -553,10 +560,10 @@ def get_indice_pairs_implicit_gemm(
]
if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
else:
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
else:
max_num_act = SpconvOps.get_handcrafted_max_act_out(
indices.shape[0], ksize, stride, padding, dilation)
......@@ -648,7 +655,7 @@ def get_indice_pairs_implicit_gemm(
-1,
dtype=indices.dtype,
device=indices.device)
pair_mask_fwd = torch.zeros((mask_split_count, num_act_out),
pair_mask_fwd = torch.zeros((mask_split_count, num_act_out * mask_int_count),
dtype=torch.int32,
device=indices.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
......@@ -658,7 +665,7 @@ def get_indice_pairs_implicit_gemm(
pair_mask_bwd_tv = tv.Tensor()
if is_train:
pair_mask_bwd = torch.zeros(
(mask_split_count, indices.shape[0]),
(mask_split_count, indices.shape[0] * mask_int_count),
dtype=torch.int32,
device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
......@@ -706,7 +713,8 @@ def get_indice_pairs_implicit_gemm(
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream)
stream_int=stream,
mask_int_count=mask_int_count)
mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32,
device=indices.device)
......@@ -758,25 +766,26 @@ def get_indice_pairs_implicit_gemm(
else:
# if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
if not is_train:
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0],
SpconvOps.sort_1d_by_key_allocator_mask_auto(pair_mask_fwd_tv[0],
alloc.alloc,
mask_argsort_fwd_tv[0],
stream)
stream,
mask_int_count)
else:
if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
SpconvOps.sort_1d_by_key_allocator(
SpconvOps.sort_1d_by_key_allocator_mask_auto(
pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream)
SpconvOps.sort_1d_by_key_allocator(
mask_argsort_bwd_tv[0], stream, mask_int_count)
SpconvOps.sort_1d_by_key_allocator_mask_auto(
pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream)
mask_argsort_fwd_tv[0], stream, mask_int_count)
else:
SpconvOps.sort_1d_by_key_allocator(
SpconvOps.sort_1d_by_key_allocator_mask_auto(
pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream)
SpconvOps.sort_1d_by_key_allocator(
mask_argsort_fwd_tv[0], stream, mask_int_count)
SpconvOps.sort_1d_by_key_allocator_mask_auto(
pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream)
mask_argsort_bwd_tv[0], stream, mask_int_count)
# CONV.stream_synchronize(stream)
if not is_train:
......@@ -799,7 +808,7 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks, mask_int_count)
def indice_conv(features: torch.Tensor,
......@@ -1448,6 +1457,7 @@ def implicit_gemm(features: torch.Tensor,
mask_argsort_fwd_splits: List[torch.Tensor],
num_activate_out: int,
masks: List[np.ndarray],
mask_int_count: int,
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
......@@ -1491,7 +1501,7 @@ def implicit_gemm(features: torch.Tensor,
mask_width, tune_res_cpp = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream,
num_activate_out, mask_tv, mask_int_count, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type,
use_tf32=constants.SPCONV_ALLOW_TF32)
out_features = alloc.allocated[AllocKeys.OutFeatures]
......@@ -1557,13 +1567,14 @@ def implicit_gemm(features: torch.Tensor,
mask_filter=masks[0].item(),
stream=stream,
fp32_accum=fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32)
use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
mask_width = tune_res.algo_desp.tile_shape[0]
if is_train:
mask_output_fwd = torch.empty(
[num_split,
codeops.div_up(num_activate_out, mask_width)],
codeops.div_up(num_activate_out, mask_width) * mask_int_count],
dtype=torch.int32,
device=features.device)
# pytorch don't support uint32.
......@@ -1611,7 +1622,8 @@ def implicit_gemm(features: torch.Tensor,
bias=bias_tv,
act_type=act_type,
act_alpha=act_alpha,
act_beta=act_beta)
act_beta=act_beta,
mask_int_count=mask_int_count)
# INT8_TEST = True
# if INT8_TEST:
# if features.shape[1] % 32 != 0:
......@@ -1710,6 +1722,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_argsort_bwd_splits: List[torch.Tensor],
mask_output_fwd: Optional[torch.Tensor],
masks: List[np.ndarray],
mask_int_count: int,
mask_width: int,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
......@@ -1769,7 +1782,7 @@ def implicit_gemm_backward(features: torch.Tensor,
alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream,
mask_output_fwd_tv, mask_tv, mask_int_count, arch, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32)
din = alloc.allocated[AllocKeys.DIn]
......@@ -1848,7 +1861,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[0].item(),
stream=stream,
fp32_accum=fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32)
use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
if wgrad_tune_res is None:
wgrad_tune_res, _ = CONV.tune_and_cache(
ConvOpType.kBackwardWeight,
......@@ -1867,7 +1881,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_output=tv.Tensor(),
mask_width=mask_width,
stream=stream,
use_tf32=constants.SPCONV_ALLOW_TF32)
use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
wgrad_tune_res.splitk,
ConvOpType.kBackwardWeight,
......@@ -1904,7 +1919,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[j].item(),
mask_width=-1,
beta=beta,
stream=stream)
stream=stream,
mask_int_count=mask_int_count)
# for backward weight, beta = 0 because each split
# handle different kernel locations.
# TODO remove D iterator in backward weight kernel
......@@ -1923,7 +1939,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_width=mask_width,
beta=0,
workspace=workspace_tv,
stream=stream)
stream=stream,
mask_int_count=mask_int_count)
return (din, dfilters.reshape(filters_shape))
......
......@@ -291,6 +291,335 @@ class Net2(nn.Module):
return self.net(x)
class Net_kv75(nn.Module):
def __init__(self, shape, algo):
super().__init__()
pool_algo = algo
# pool_algo = ConvAlgo.Native
self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 64, (3, 5, 5), bias=False, indice_key="c0",
algo=algo),
# spconv.SubMConv3d(32,
# 32,
# 3,
# bias=False,
# indice_key="c0",
# algo=algo),
# # nn.BatchNorm1d(32),
# # nn.ReLU(),
# # spconv.SparseConv3d(64, 64, 2, 2, bias=False,
# # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo),
# spconv.SubMConv3d(64, 64, 3, bias=False, indice_key="c0",
# algo=algo),
# spconv.SubMConv3d(32,
# 32,
# 3,
# bias=False,
# indice_key="c0",
# algo=algo),
# # nn.BatchNorm1d(32),
# # nn.ReLU(),
# # spconv.SparseConv3d(64, 64, 2, 2, bias=False,
# # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo),
spconv.SubMConv3d(64,
64,
(3, 5, 5),
bias=False,
indice_key="c0",
algo=algo),
# nn.BatchNorm1d(32),
# nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(64,
96,
(3, 5, 5),
bias=False,
indice_key="c1",
algo=algo),
spconv.SubMConv3d(96,
96,
(3, 5, 5),
bias=False,
indice_key="c1",
algo=algo),
# nn.BatchNorm1d(64),
# nn.ReLU(),
# spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(96,
128,
(3, 5, 5),
bias=False,
indice_key="c2",
algo=algo),
spconv.SubMConv3d(128,
128,
(3, 5, 5),
bias=False,
indice_key="c2",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
# spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(128,
160,
(3, 5, 5),
bias=False,
indice_key="c3",
algo=algo),
spconv.SubMConv3d(160,
160,
(3, 5, 5),
bias=False,
indice_key="c3",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
# spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(160,
192,
(3, 5, 5),
bias=False,
indice_key="c4",
algo=algo),
spconv.SubMConv3d(192,
192,
(3, 5, 5),
bias=False,
indice_key="c4",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo),
# spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"),
spconv.SubMConv3d(192,
224,
(3, 5, 5),
bias=False,
indice_key="c5",
algo=algo),
spconv.SubMConv3d(224,
224,
(3, 5, 5),
bias=False,
indice_key="c5",
algo=algo),
# nn.BatchNorm1d(224),
# nn.ReLU(),
# spconv.SparseConv3d(224, 224, 2, 2, bias=False, indice_key="m5"),
spconv.SparseMaxPool3d(2, 2, indice_key="m5", algo=pool_algo),
spconv.SparseConv3d(224,
256,
(3, 5, 5),
padding=(1, 2, 2),
bias=False,
# indice_key="c6",
algo=algo),
spconv.SubMConv3d(256,
256,
(3, 5, 5),
bias=False,
indice_key="c6",
algo=algo),
# nn.BatchNorm1d(256),
# nn.ReLU(),
# spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo),
# # # nn.BatchNorm1d(128),
# # # nn.ReLU(),
# spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo),
)
max_batch_size = 1
# grid (dense map) is used for indice generation. use pre-allocated grid can run faster.
# self.grid = torch.full([max_batch_size, *shape], -1,
# dtype=torch.int32).cuda()
# self.grid = None
self.shape = shape
def forward(self, features, coors, batch_size, enable_timer: bool = False):
x = spconv.SparseConvTensor(features,
coors,
self.shape,
batch_size,
# self.grid,
enable_timer=enable_timer)
return self.net(x)
class Net_kv125(nn.Module):
def __init__(self, shape, algo):
super().__init__()
pool_algo = algo
# pool_algo = ConvAlgo.Native
self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 64, 5, bias=False, indice_key="c0",
algo=algo),
# spconv.SubMConv3d(32,
# 32,
# 3,
# bias=False,
# indice_key="c0",
# algo=algo),
# # nn.BatchNorm1d(32),
# # nn.ReLU(),
# # spconv.SparseConv3d(64, 64, 2, 2, bias=False,
# # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo),
# spconv.SubMConv3d(64, 64, 3, bias=False, indice_key="c0",
# algo=algo),
# spconv.SubMConv3d(32,
# 32,
# 3,
# bias=False,
# indice_key="c0",
# algo=algo),
# # nn.BatchNorm1d(32),
# # nn.ReLU(),
# # spconv.SparseConv3d(64, 64, 2, 2, bias=False,
# # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo),
spconv.SubMConv3d(64,
64,
5,
bias=False,
indice_key="c0",
algo=algo),
# nn.BatchNorm1d(32),
# nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(64,
96,
5,
bias=False,
indice_key="c1",
algo=algo),
spconv.SubMConv3d(96,
96,
5,
bias=False,
indice_key="c1",
algo=algo),
# nn.BatchNorm1d(64),
# nn.ReLU(),
# spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(96,
128,
5,
bias=False,
indice_key="c2",
algo=algo),
spconv.SubMConv3d(128,
128,
5,
bias=False,
indice_key="c2",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
# spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(128,
160,
5,
bias=False,
indice_key="c3",
algo=algo),
spconv.SubMConv3d(160,
160,
5,
bias=False,
indice_key="c3",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
# spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(160,
192,
5,
bias=False,
indice_key="c4",
algo=algo),
spconv.SubMConv3d(192,
192,
5,
bias=False,
indice_key="c4",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo),
# spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"),
spconv.SubMConv3d(192,
224,
5,
bias=False,
indice_key="c5",
algo=algo),
spconv.SubMConv3d(224,
224,
5,
bias=False,
indice_key="c5",
algo=algo),
# nn.BatchNorm1d(224),
# nn.ReLU(),
# spconv.SparseConv3d(224, 224, 2, 2, bias=False, indice_key="m5"),
spconv.SparseMaxPool3d(2, 2, indice_key="m5", algo=pool_algo),
spconv.SubMConv3d(224,
256,
5,
bias=False,
indice_key="c6",
algo=algo),
spconv.SubMConv3d(256,
256,
5,
bias=False,
indice_key="c6",
algo=algo),
# nn.BatchNorm1d(256),
# nn.ReLU(),
# spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo),
# # # nn.BatchNorm1d(128),
# # # nn.ReLU(),
# spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo),
)
max_batch_size = 1
# grid (dense map) is used for indice generation. use pre-allocated grid can run faster.
# self.grid = torch.full([max_batch_size, *shape], -1,
# dtype=torch.int32).cuda()
# self.grid = None
self.shape = shape
def forward(self, features, coors, batch_size, enable_timer: bool = False):
x = spconv.SparseConvTensor(features,
coors,
self.shape,
batch_size,
# self.grid,
enable_timer=enable_timer)
return self.net(x)
class NetSm(nn.Module):
def __init__(self, shape, algo):
......@@ -395,9 +724,9 @@ def main():
np.random.seed(50051)
torch.manual_seed(50051)
# voxels, coors, spatial_shape = waymo_data(num_features=3)
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f)
# voxels, coors, spatial_shape = waymo_data_large()
# with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
# (voxels, coors, spatial_shape) = pickle.load(f)
voxels, coors, spatial_shape = waymo_data_large()
# breakpoint()
print(spatial_shape)
......@@ -474,17 +803,18 @@ def main():
# state.pop("net.2.max_num_voxels_during_training")
# net.load_state_dict(state)
# breakpoint()
# net = net.train()
print("spconv time", np.mean(times[10:]))
# times = []
# for i in range(10):
# out = net(voxels_th, coors_th, 1)
# print("------------")
# # torch.cuda.synchronize()
# # t = time.time()
# torch.cuda.synchronize()
# t = time.time()
# out.features.backward(dout_t)
# # torch.cuda.synchronize()
# # times.append(time.time() - t)
# torch.cuda.synchronize()
# times.append(time.time() - t)
# # # print((net.grid == -1).float().sum(), net.grid.numel())
# # # print("spconv time", time.time() - t)
......
......@@ -135,6 +135,7 @@ class SparseConvTester:
self.mask_argsort_fwd_splits = res[6]
self.mask_argsort_bwd_splits = res[7]
self.masks = res[8]
self.mask_int_count = res[9]
self.out_inds_scalar = Fsp._indice_to_scalar(self.out_inds.long(), [bs, *out_shape])
......@@ -293,12 +294,12 @@ def _test_impgemm_conv_cuda(subm: bool):
multiple_base = 16
if subm:
ksizes = [3]
ksizes = [3, (3, 3, 5), (3, 5, 5), 5]
strides = [1]
paddings = [0]
dilations = [1]
else:
ksizes = [2, 3]
ksizes = [2, 3, (3, 3, 4), 4, (4, 5, 5), 5]
strides = [1, 2, 3]
paddings = [0, 1]
dilations = [1, 2]
......@@ -354,8 +355,9 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_width = desp.tile_shape[0]
# if mask_width != 32:
# continue
if mask_width not in mask_width_to_mask_out_fwd:
mask_width_to_mask_out_fwd[mask_width] = torch.zeros([2, div_up(tester.out_inds.shape[0], mask_width)],
mask_width_to_mask_out_fwd[mask_width] = torch.zeros([2, tester.mask_int_count * div_up(tester.out_inds.shape[0], mask_width)],
dtype=torch.int32,
device=tester.device)
mask_output_fwd = mask_width_to_mask_out_fwd[mask_width]
......@@ -413,6 +415,7 @@ def _test_impgemm_conv_cuda(subm: bool):
force_nvrtc=force_nvrtc,
bias=bias_cur if is_fwd and bias_cur is not None else tv.Tensor(),
act_type=act,
mask_int_count=tester.mask_int_count,
)
else:
CONV.run_with_tuned_result(
......@@ -433,6 +436,7 @@ def _test_impgemm_conv_cuda(subm: bool):
force_nvrtc=force_nvrtc,
bias=bias_cur if is_fwd else None,
act_type=act,
mask_int_count=tester.mask_int_count
)
else:
......@@ -491,6 +495,7 @@ def _test_impgemm_conv_cuda(subm: bool):
force_nvrtc=force_nvrtc,
bias=bias if is_fwd and bias is not None else tv.Tensor(),
act_type=act,
mask_int_count=tester.mask_int_count,
)
else:
CONV.run_with_tuned_result(
......@@ -511,6 +516,7 @@ def _test_impgemm_conv_cuda(subm: bool):
force_nvrtc=force_nvrtc,
bias=bias if is_fwd else None,
act_type=act,
mask_int_count=tester.mask_int_count,
)
out_ref = tester.out_ref
......@@ -572,6 +578,7 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_width=mask_width,
beta=beta,
verbose=False,
mask_int_count=tester.mask_int_count,
)
else:
indice_pairs = tester.pair_fwd # inp -> out
......@@ -599,6 +606,7 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_width=mask_width,
beta=beta,
verbose=False,
mask_int_count=tester.mask_int_count,
)
dw_ref = tester.dw_ref
dw_my = weight_tv.cpu().numpy()
......@@ -918,8 +926,8 @@ def test_all_algo_unit():
# for i in range(5):
_test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda(False)
_test_native_conv_cuda(True)
_test_native_conv_cuda(False)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(False)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment