Unverified Commit ee8c9465 authored by FindDefinition's avatar FindDefinition Committed by GitHub
Browse files

Large kernel for implicit gemm (#547)



* large kernel bwd&bwdI, not test increment RS

* large kernel fix, no split_mask and increment rs

* large kernel fix2, no split_mask and increment rs

* reset benchmark.py

* fix merge
Co-authored-by: EvernightAurora <2465542858@qq.com>
parent bdfbf4a2
...@@ -618,6 +618,7 @@ class SimpleConv: ...@@ -618,6 +618,7 @@ class SimpleConv:
] ]
self.prebuilt_desps = prebuilt_desps self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps} self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
self.lock = Lock() self.lock = Lock()
self.static_key_to_desps = group_by(self.get_static_key, all_desps) self.static_key_to_desps = group_by(self.get_static_key, all_desps)
...@@ -823,6 +824,7 @@ class SimpleConv: ...@@ -823,6 +824,7 @@ class SimpleConv:
mask_argsort: tv.Tensor, mask_argsort: tv.Tensor,
indices: tv.Tensor, indices: tv.Tensor,
reverse_mask: bool, reverse_mask: bool,
mask_int_count: int = 1,
mask_filter: int = 0xffffffff, mask_filter: int = 0xffffffff,
mask_width: int = -1, mask_width: int = -1,
mask_output: tv.Tensor = tv.Tensor(), mask_output: tv.Tensor = tv.Tensor(),
...@@ -863,6 +865,8 @@ class SimpleConv: ...@@ -863,6 +865,8 @@ class SimpleConv:
params.indices = indices params.indices = indices
params.mask = mask params.mask = mask
params.mask_output = mask_output params.mask_output = mask_output
params.mask_int_count = mask_int_count
# if op_type == ConvOpType.kBackwardWeight: # if op_type == ConvOpType.kBackwardWeight:
# assert not mask_output.empty() # assert not mask_output.empty()
if op_type == ConvOpType.kBackwardInput: if op_type == ConvOpType.kBackwardInput:
...@@ -940,7 +944,8 @@ class SimpleConv: ...@@ -940,7 +944,8 @@ class SimpleConv:
bias: Optional[tv.Tensor] = None, bias: Optional[tv.Tensor] = None,
act_alpha: float = 0.0, act_alpha: float = 0.0,
act_beta: float = 0.0, act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_): act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
mask_int_count: Union[int, None] = None):
channel_k = output.dim(1) channel_k = output.dim(1)
channel_c = inp.dim(1) channel_c = inp.dim(1)
# GemmMainUnitTest.stream_synchronize(stream) # GemmMainUnitTest.stream_synchronize(stream)
...@@ -981,6 +986,7 @@ class SimpleConv: ...@@ -981,6 +986,7 @@ class SimpleConv:
params.mask_filter = mask_filter params.mask_filter = mask_filter
params.mask_output = mask_output params.mask_output = mask_output
params.reverse_mask = reverse_mask params.reverse_mask = reverse_mask
params.mask_int_count = mask_int_count
if bias is not None: if bias is not None:
params.bias = bias params.bias = bias
if timer.enable: if timer.enable:
......
...@@ -144,7 +144,7 @@ class SpconvOps: ...@@ -144,7 +144,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int: def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
""" """
Args: Args:
indices: indices:
...@@ -167,10 +167,11 @@ class SpconvOps: ...@@ -167,10 +167,11 @@ class SpconvOps:
dilation: dilation:
transposed: transposed:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int: def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
""" """
Args: Args:
indices: indices:
...@@ -193,10 +194,11 @@ class SpconvOps: ...@@ -193,10 +194,11 @@ class SpconvOps:
dilation: dilation:
transposed: transposed:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int: def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
""" """
Args: Args:
indices: indices:
...@@ -212,6 +214,7 @@ class SpconvOps: ...@@ -212,6 +214,7 @@ class SpconvOps:
indice_pair_mask: indice_pair_mask:
backward: backward:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
...@@ -380,7 +383,7 @@ class SpconvOps: ...@@ -380,7 +383,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor: def sort_1d_by_key_allocator_mask32(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
""" """
Args: Args:
data: data:
...@@ -390,6 +393,58 @@ class SpconvOps: ...@@ -390,6 +393,58 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_allocator_mask32_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
    """Argsort ``data`` on the GPU, one integer per sort key.

    Stub for the binding built by ``sort_1d_by_key_allocator_template(True)``,
    i.e. with ``int_count`` left at its default of 1.

    Args:
        data: key tensor; the generated code sorts it in place with
            ``thrust::sort_by_key`` and dispatches on int32/uint32/int64/uint64.
        allocator: allocator handed to the thrust execution policy
            (presumably a ``ThrustAllocator`` -- TODO confirm against the
            template's argument list).
        indices: optional int32 tensor that receives the result. When empty,
            a tensor with ``data.dim(0)`` entries is allocated. It is filled
            with 0..n-1 and then permuted together with the keys.
        stream: CUDA stream handle passed as an integer.

    Returns:
        ``indices``, holding the permutation that sorts the keys.
    """
    ...
@staticmethod
def sort_1d_by_key_allocator_mask128(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
    """Argsort ``data`` on the GPU, treating every 4 consecutive integers as one key.

    Stub for the binding built by ``sort_1d_by_key_allocator_template(False, 4)``.

    Args:
        data: key tensor, sorted in place. The generated code reinterprets it
            as ``data.dim(0) / 4`` tuples of 4 integers and sorts those tuples.
        alloc_func: scratch-memory callback used by the sort; presumably takes
            a byte count and returns a device pointer as an int (TODO confirm
            against the template's argument list).
        indices: optional int32 tensor that receives the result. When empty,
            a tensor with ``data.dim(0) / 4`` entries is allocated. It is
            filled with 0..n-1 and then permuted together with the keys.
        stream: CUDA stream handle passed as an integer.

    Returns:
        ``indices``, holding the permutation that sorts the 4-integer keys.
    """
    ...
@staticmethod
def sort_1d_by_key_allocator_mask128_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
    """Argsort ``data`` on the GPU with 4-integer keys, allocator-based variant.

    Stub for the binding built by ``sort_1d_by_key_allocator_template(True, 4)``.

    Args:
        data: key tensor, sorted in place. The generated code reinterprets it
            as ``data.dim(0) / 4`` tuples of 4 integers and sorts those tuples.
        allocator: allocator handed to the thrust execution policy
            (presumably a ``ThrustAllocator`` -- TODO confirm against the
            template's argument list).
        indices: optional int32 tensor that receives the result. When empty,
            a tensor with ``data.dim(0) / 4`` entries is allocated. It is
            filled with 0..n-1 and then permuted together with the keys.
        stream: CUDA stream handle passed as an integer.

    Returns:
        ``indices``, holding the permutation that sorts the 4-integer keys.
    """
    ...
@staticmethod
def sort_1d_by_key_allocator_mask_auto(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
    """Dispatch to the mask32 or mask128 sort-by-key according to ``mask_int_count``.

    The generated code forwards ``data``, ``alloc_param``, ``indices`` and
    ``stream`` unchanged to ``sort_1d_by_key_allocator_mask32`` when
    ``mask_int_count`` is 1 and to ``sort_1d_by_key_allocator_mask128`` when
    it is 4. Any other value hits a failing ``TV_ASSERT_RT_ERR``.

    Args:
        data: key tensor to sort; see the mask32/mask128 variants.
        alloc_param: allocation callback; declared on the C++ side as
            ``std::function<std::uintptr_t(std::size_t)>``.
        indices: optional output tensor, passed through to the selected variant.
        stream: CUDA stream handle passed as an integer.
        mask_int_count: number of integers that make up one mask key;
            only 1 and 4 are handled.

    Returns:
        The tensor returned by the selected variant.
    """
    ...
@staticmethod
def sort_1d_by_key_allocator_mask_auto_v2(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
    """Dispatch to the ``_v2`` mask32 or mask128 sort-by-key according to ``mask_int_count``.

    The generated code forwards ``data``, ``alloc_param``, ``indices`` and
    ``stream`` unchanged to ``sort_1d_by_key_allocator_mask32_v2`` when
    ``mask_int_count`` is 1 and to ``sort_1d_by_key_allocator_mask128_v2``
    when it is 4. Any other value hits a failing ``TV_ASSERT_RT_ERR``.

    Args:
        data: key tensor to sort; see the mask32_v2/mask128_v2 variants.
        alloc_param: allocator object; declared on the C++ side as
            ``ThrustAllocator&``.
        indices: optional output tensor, passed through to the selected variant.
        stream: CUDA stream handle passed as an integer.
        mask_int_count: number of integers that make up one mask key;
            only 1 and 4 are handled.

    Returns:
        The tensor returned by the selected variant.
    """
    ...
@staticmethod
def sort_1d_by_key_split(data: Tensor, mask: Tensor, indices: Tensor = Tensor(), stream: int = 0, mask_output: bool = False) -> Tensor: def sort_1d_by_key_split(data: Tensor, mask: Tensor, indices: Tensor = Tensor(), stream: int = 0, mask_output: bool = False) -> Tensor:
""" """
Args: Args:
...@@ -543,7 +598,7 @@ class SpconvOps: ...@@ -543,7 +598,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]: def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int, int]:
""" """
Args: Args:
allocator: allocator:
......
...@@ -48,7 +48,7 @@ class ConvTunerSimple: ...@@ -48,7 +48,7 @@ class ConvTunerSimple:
stream_int: stream_int:
""" """
... ...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]: def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]:
""" """
Args: Args:
op_type: op_type:
...@@ -72,6 +72,7 @@ class ConvTunerSimple: ...@@ -72,6 +72,7 @@ class ConvTunerSimple:
alpha: alpha:
beta: beta:
stream_int: stream_int:
mask_int_count:
auto_fp32_accum: auto_fp32_accum:
fp32_accum: fp32_accum:
num_run: num_run:
...@@ -91,7 +92,7 @@ class ConvTunerSimple: ...@@ -91,7 +92,7 @@ class ConvTunerSimple:
mask_width: mask_width:
""" """
... ...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None: def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None:
""" """
Args: Args:
profile_res: profile_res:
...@@ -109,6 +110,7 @@ class ConvTunerSimple: ...@@ -109,6 +110,7 @@ class ConvTunerSimple:
alpha: alpha:
beta: beta:
stream_int: stream_int:
mask_int_count:
workspace: workspace:
verbose: verbose:
timer: timer:
......
...@@ -63,7 +63,7 @@ class ConvGemmOps: ...@@ -63,7 +63,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]: def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]:
""" """
Args: Args:
allocator: allocator:
...@@ -75,6 +75,7 @@ class ConvGemmOps: ...@@ -75,6 +75,7 @@ class ConvGemmOps:
mask_argsort_fwd_splits: mask_argsort_fwd_splits:
num_activate_out: num_activate_out:
masks: masks:
mask_int_count:
arch: arch:
is_train: is_train:
is_subm: is_subm:
...@@ -90,7 +91,7 @@ class ConvGemmOps: ...@@ -90,7 +91,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None: def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
""" """
Args: Args:
allocator: allocator:
...@@ -106,6 +107,7 @@ class ConvGemmOps: ...@@ -106,6 +107,7 @@ class ConvGemmOps:
mask_argsort_bwd_splits: mask_argsort_bwd_splits:
mask_output_fwd: mask_output_fwd:
masks: masks:
mask_int_count:
arch: arch:
mask_width: mask_width:
is_subm: is_subm:
......
...@@ -462,6 +462,7 @@ class SpconvOps(pccm.Class): ...@@ -462,6 +462,7 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>") code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim && TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...@@ -488,7 +489,7 @@ class SpconvOps(pccm.Class): ...@@ -488,7 +489,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort, indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd, out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_, num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int); ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -512,6 +513,7 @@ class SpconvOps(pccm.Class): ...@@ -512,6 +513,7 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>") code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim && TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...@@ -538,7 +540,7 @@ class SpconvOps(pccm.Class): ...@@ -538,7 +540,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort, indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd, out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_, num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int); ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -559,6 +561,7 @@ class SpconvOps(pccm.Class): ...@@ -559,6 +561,7 @@ class SpconvOps(pccm.Class):
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("backward", "bool", "false") code.arg("backward", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim && TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
...@@ -579,7 +582,7 @@ class SpconvOps(pccm.Class): ...@@ -579,7 +582,7 @@ class SpconvOps(pccm.Class):
indice_pairs, out_inds, indice_num_per_loc, indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_, batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward, ksize_, dilation_, indice_pair_mask, backward,
stream_int); stream_int, mask_int_count);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -906,7 +909,7 @@ class SpconvOps(pccm.Class): ...@@ -906,7 +909,7 @@ class SpconvOps(pccm.Class):
""") """)
return code return code
def sort_1d_by_key_allocator_template(self, use_allocator: bool): def sort_1d_by_key_allocator_template(self, use_allocator: bool, int_count: int = 1):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
return code.make_invalid() return code.make_invalid()
...@@ -942,18 +945,19 @@ class SpconvOps(pccm.Class): ...@@ -942,18 +945,19 @@ class SpconvOps(pccm.Class):
code.raw(f""" code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream); cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{ if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0); indices = tv::empty({{data.dim(0) / {int_count}}}, tv::int32, 0);
}} }}
tv::cuda::Launch launcher(data.dim(0), stream_cu); tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0)); launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
// auto timer = tv::CUDATimer(); // auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{ tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I); using T_ = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>()); using T = {"T_" if int_count == 1 else f"thrust::tuple<{', '.join(['T_'] * int_count) }>"};
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>()); thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu); auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu); auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k); thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0) / {int_count}, ptr_k);
}}); }});
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0); // tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices; return indices;
...@@ -963,16 +967,73 @@ class SpconvOps(pccm.Class): ...@@ -963,16 +967,73 @@ class SpconvOps(pccm.Class):
@pccm.pybind.mark @pccm.pybind.mark
@_STATIC_FUNCTION @_STATIC_FUNCTION
def sort_1d_by_key_allocator(self): def sort_1d_by_key_allocator_mask32(self):
# for python # for python
return self.sort_1d_by_key_allocator_template(False) return self.sort_1d_by_key_allocator_template(False)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask32_v2(self):
    # Pybind-marked, allocator-based instantiation of the sort-by-key
    # template. int_count stays at its default of 1, so every element of
    # `data` is its own sort key.
    return self.sort_1d_by_key_allocator_template(use_allocator=True)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128(self):
    # Pybind-marked, non-allocator instantiation of the sort-by-key template
    # with int_count=4: the generated code sorts tuples of 4 integers.
    return self.sort_1d_by_key_allocator_template(use_allocator=False, int_count=4)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128_v2(self):
    # Pybind-marked, allocator-based instantiation of the sort-by-key
    # template with int_count=4: the generated code sorts tuples of 4 integers.
    return self.sort_1d_by_key_allocator_template(use_allocator=True, int_count=4)
def sort_1d_by_key_allocator_mask_auto_template(self, use_allocator: bool):
    """Build the C++ dispatcher that selects a sort-by-key variant at runtime.

    The emitted function switches on ``mask_int_count``: 1 forwards to the
    mask32 sorter, 4 forwards to the mask128 sorter, and any other value
    hits a failing ``TV_ASSERT_RT_ERR`` followed by an empty-tensor return.

    Args:
        use_allocator: when true, ``alloc_param`` is a ``ThrustAllocator&``
            and the ``_v2`` sorters are called; otherwise it is a
            ``std::function<std::uintptr_t(std::size_t)>`` and the plain
            sorters are called.

    Returns:
        The function code with return type ``tv::Tensor``, or the result of
        ``make_invalid()`` when ``CUMM_CPU_ONLY_BUILD`` is set.
    """
    fn_code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        return fn_code.make_invalid()

    # The allocator flavour decides both the parameter type and which
    # family of sorters the switch below forwards to.
    if use_allocator:
        alloc_type, suffix = "ThrustAllocator&", "_v2"
    else:
        alloc_type, suffix = "std::function<std::uintptr_t(std::size_t)>", ""

    fn_code.arg("data", "tv::Tensor")
    fn_code.arg("alloc_param", alloc_type)
    fn_code.arg("indices",
                "tv::Tensor",
                "tv::Tensor()",
                pyanno="cumm.tensorview.Tensor = Tensor()")
    fn_code.arg("stream", "std::uintptr_t", "0", pyanno="int")
    fn_code.arg("mask_int_count", "int", "1")
    # NOTE(review): the C++ text is reproduced exactly as visible; any
    # original indentation inside this literal was not recoverable.
    fn_code.raw(f"""
switch (mask_int_count){{
case 1:
return sort_1d_by_key_allocator_mask32{suffix}(data, alloc_param, indices, stream);
case 4:
return sort_1d_by_key_allocator_mask128{suffix}(data, alloc_param, indices, stream);
default:
TV_ASSERT_RT_ERR(false, "Not implement for other mask_int_count");
return tv::Tensor();
}}
""")
    return fn_code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto(self):
    # Pybind-marked dispatcher over mask_int_count that takes the
    # std::function allocation callback (non-allocator flavour).
    return self.sort_1d_by_key_allocator_mask_auto_template(use_allocator=False)
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto_v2(self):
    # Pybind-marked dispatcher over mask_int_count that takes a
    # ThrustAllocator& and forwards to the _v2 sorters.
    return self.sort_1d_by_key_allocator_mask_auto_template(use_allocator=True)
@_STATIC_FUNCTION @_STATIC_FUNCTION
def sort_1d_by_key_allocator_v2(self): def sort_1d_by_key_allocator_v2(self):
# for cpp only # for cpp only
return self.sort_1d_by_key_allocator_template(True) return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark @pccm.pybind.mark
@_STATIC_FUNCTION @_STATIC_FUNCTION
def sort_1d_by_key_split(self): def sort_1d_by_key_split(self):
...@@ -1659,7 +1720,11 @@ class SpconvOps(pccm.Class): ...@@ -1659,7 +1720,11 @@ class SpconvOps(pccm.Class):
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int)); tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo); auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>()); int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32"); int mask_int_count = (kv + 31) / 32;
if (mask_int_count > 1 && mask_int_count < 4)
mask_int_count = 4;
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
// TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int> out_shape; std::vector<int> out_shape;
if (!subm){{ if (!subm){{
if (transposed){{ if (transposed){{
...@@ -1728,6 +1793,7 @@ class SpconvOps(pccm.Class): ...@@ -1728,6 +1793,7 @@ class SpconvOps(pccm.Class):
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>(); auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
if (is_mask_split){{ if (is_mask_split){{
TV_ASSERT_RT_ERR(mask_int_count == 1, "not support for kv > 32");
auto kv_div_2 = kv / 2; auto kv_div_2 = kv / 2;
auto remain = kv - kv_div_2; auto remain = kv - kv_div_2;
uint64_t mask_np_1 = 1; uint64_t mask_np_1 = 1;
...@@ -1779,14 +1845,14 @@ class SpconvOps(pccm.Class): ...@@ -1779,14 +1845,14 @@ class SpconvOps(pccm.Class):
pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)}); pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)});
}}else{{ }}else{{
pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)}, pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_in}}, tv::uint32, 0, stream_int); {{mask_split_count, num_act_in * mask_int_count}}, tv::uint32, 0, stream_int);
}} }}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc, generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int); batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int, mask_int_count);
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)}, auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int); {{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{ for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int); sort_1d_by_key_allocator_mask_auto_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
}} }}
""") """)
with code.else_(): with code.else_():
...@@ -1892,11 +1958,11 @@ Your Conv Params: )" << "\\n"; ...@@ -1892,11 +1958,11 @@ Your Conv Params: )" << "\\n";
pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)}, pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int); {{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)}, pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_out}}, tv::uint32, 0, stream_int); {{mask_split_count, num_act_out * mask_int_count}}, tv::uint32, 0, stream_int);
pair_mask_bwd = tv::Tensor(); pair_mask_bwd = tv::Tensor();
if (is_train){{ if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)}, pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0, stream_int); {{mask_split_count, indices.dim(0) * mask_int_count}}, tv::uint32, 0, stream_int);
}} }}
}} }}
if (!direct_table){{ if (!direct_table){{
...@@ -1928,13 +1994,13 @@ Your Conv Params: )" << "\\n"; ...@@ -1928,13 +1994,13 @@ Your Conv Params: )" << "\\n";
indice_pairs_uniq, indice_pairs_uniq_bkp, indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out, out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation, batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int); transposed, stream_int, mask_int_count);
}}else{{ }}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd, generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp, indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out, out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation, batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int); transposed, stream_int, mask_int_count);
}} }}
}} }}
""") """)
...@@ -1964,21 +2030,21 @@ Your Conv Params: )" << "\\n"; ...@@ -1964,21 +2030,21 @@ Your Conv Params: )" << "\\n";
}} }}
}}else{{ }}else{{
if (!is_train){{ if (!is_train){{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc, sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int); mask_argsort_fwd[0], stream_int, mask_int_count);
}}else{{ }}else{{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc, sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int); mask_argsort_fwd[0], stream_int, mask_int_count);
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc, sort_1d_by_key_allocator_mask_auto_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int); mask_argsort_bwd[0], stream_int, mask_int_count);
}} }}
}} }}
}} }}
""") """)
code.raw(f""" code.raw(f"""
return std::make_tuple(mask_tensor, num_act_out); return std::make_tuple(mask_tensor, num_act_out, mask_int_count);
""") """)
return code.ret("std::tuple<tv::Tensor, int>") return code.ret("std::tuple<tv::Tensor, int, int>")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
......
...@@ -1138,6 +1138,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1138,6 +1138,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0") code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.arg("auto_fp32_accum", "bool", "true") code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false") code.arg("fp32_accum", "bool", "false")
code.arg("num_run", "int", "5") code.arg("num_run", "int", "5")
...@@ -1186,6 +1187,8 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1186,6 +1187,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.indices = indices; params.indices = indices;
params.mask = mask; params.mask = mask;
params.mask_output = mask_output; params.mask_output = mask_output;
params.mask_int_count = mask_int_count;
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{ // if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error"); // TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
// }} // }}
...@@ -1335,7 +1338,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1335,7 +1338,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0") code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.arg("workspace", "tv::Tensor", "tv::Tensor()", code.arg("workspace", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("verbose", f"bool", "false") code.arg("verbose", f"bool", "false")
...@@ -1347,7 +1350,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1347,7 +1350,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("act_alpha", f"float", "0.0") code.arg("act_alpha", f"float", "0.0")
code.arg("act_beta", f"float", "0.0") code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_") code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code return code
...@@ -1390,6 +1393,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1390,6 +1393,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.mask_width = mask_width; params.mask_width = mask_width;
params.mask_output = mask_output; params.mask_output = mask_output;
params.reverse_mask = reverse_mask; params.reverse_mask = reverse_mask;
params.mask_int_count = mask_int_count;
if (timer.enable()){{ if (timer.enable()){{
params.timer = timer; params.timer = timer;
...@@ -2035,6 +2039,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2035,6 +2039,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>") "std::vector<tv::Tensor>")
code.arg("num_activate_out", "int") code.arg("num_activate_out", "int")
code.arg("masks", "tv::Tensor") code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("is_train, is_subm", "bool", "false") code.arg("is_train, is_subm", "bool", "false")
...@@ -2108,6 +2113,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2108,6 +2113,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count, // mask_int_count is after stream_int
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
...@@ -2120,7 +2126,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2120,7 +2126,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
std::vector<tv::Tensor> mask_output_fwd_splits; std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{ if (is_train){{
mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)}, mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)},
{{num_split, tv::div_up(num_activate_out, mask_width)}}, {{num_split, tv::div_up(num_activate_out, mask_width) * mask_int_count}},
tv::uint32, features.device(), stream_int); tv::uint32, features.device(), stream_int);
for (int i = 0; i < num_split; ++i){{ for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]); mask_output_fwd_splits.push_back(mask_output_fwd[i]);
...@@ -2154,6 +2160,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2154,6 +2160,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width -1, // mask_width
1.0, beta, 1.0, beta,
stream_int, stream_int,
mask_int_count, // mask_int_count is after stream_int
tv::Tensor(), // workspace tv::Tensor(), // workspace
false, // verbose false, // verbose
timer, timer,
...@@ -2186,6 +2193,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2186,6 +2193,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("mask_output_fwd", "tv::Tensor") code.arg("mask_output_fwd", "tv::Tensor")
code.arg("masks", "tv::Tensor") code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int") code.arg("mask_width", "int")
...@@ -2278,6 +2286,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2278,6 +2286,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
...@@ -2302,6 +2311,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2302,6 +2311,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
...@@ -2344,6 +2354,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2344,6 +2354,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width -1, // mask_width
1.0, beta, 1.0, beta,
stream_int, stream_int,
mask_int_count,
tv::Tensor(), // workspace tv::Tensor(), // workspace
false, // verbose false, // verbose
timer); timer);
...@@ -2361,6 +2372,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2361,6 +2372,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_width, mask_width,
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
workspace, // workspace workspace, // workspace
false, // verbose false, // verbose
timer); timer);
......
...@@ -613,10 +613,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -613,10 +613,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
code.arg("num_indices_out", "int") code.arg("num_indices_out", "int")
code.arg("mask_int_count", "int")
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
uint32_t filter_mask_fwd = (1u << (filter_offset)); int filter_pointer_offset = filter_offset / 32;
uint32_t filter_mask_fwd = (1u << (filter_offset % 32));
// TODO following rule for even kernel size is wrong. // TODO following rule for even kernel size is wrong.
// uint32_t filter_mask_bwd = (1u << (gridDim.y - 1 - filter_offset)); // uint32_t filter_mask_bwd = (1u << (gridDim.y - 1 - filter_offset));
...@@ -633,7 +635,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -633,7 +635,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto output_index = table.value_ptr()[table_offset]; auto output_index = table.value_ptr()[table_offset];
bool valid = CheckValueValid ? output_index >= 0 : true; bool valid = CheckValueValid ? output_index >= 0 : true;
if (valid){{ if (valid){{
atomicOr(mask_fwd + output_index, filter_mask_fwd); atomicOr(mask_fwd + output_index * mask_int_count + filter_pointer_offset, filter_mask_fwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd); // atomicOr(mask_bwd + input_index, filter_mask_bwd);
indice_pairs_fwd_filter[output_index] = input_index; indice_pairs_fwd_filter[output_index] = input_index;
if (indice_pairs_bwd != nullptr){{ if (indice_pairs_bwd != nullptr){{
...@@ -655,15 +657,19 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -655,15 +657,19 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
code.arg("kv", "int") code.arg("kv", "int")
code.arg("mask_int_count", "int")
code.raw(f""" code.raw(f"""
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{ for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
uint32_t mask = 0; for (int mask_offset = 0; mask_offset < mask_int_count; ++mask_offset){{
for (int filter_offset = 0; filter_offset < kv; ++filter_offset){{ uint32_t mask = 0;
auto val = indice_pairs_bwd[filter_offset * num_indices_in + input_index]; for (int filter_offset = mask_offset * 32; filter_offset < mask_offset * 32 + 32 && filter_offset < kv; ++filter_offset){{
mask |= (val != -1) << filter_offset; auto val = indice_pairs_bwd[filter_offset * num_indices_in + input_index];
mask |= (val != -1) << (filter_offset % 32);
}}
mask_bwd[input_index * mask_int_count + mask_offset] = mask;
}} }}
mask_bwd[input_index] = mask;
}} }}
""") """)
return code return code
...@@ -685,11 +691,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -685,11 +691,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("mask_fwd", f"uint32_t*") # [kernelProd] code.arg("mask_fwd", f"uint32_t*") # [kernelProd]
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
code.arg("num_indices_out", "int") code.arg("num_indices_out", "int")
code.arg("mask_int_count", "int")
# TODO use block instead of filter_offset? # TODO use block instead of filter_offset?
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
uint32_t filter_mask_fwd = (1u << (filter_offset)); int filter_pointer_offset = filter_offset / 32;
uint32_t filter_mask_fwd = (1u << (filter_offset % 32));
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out; auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
// auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in; // auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
...@@ -702,7 +710,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -702,7 +710,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto output_index = table.value_ptr()[table_offset]; auto output_index = table.value_ptr()[table_offset];
bool valid = CheckValueValid ? output_index >= 0 : true; bool valid = CheckValueValid ? output_index >= 0 : true;
if (valid){{ if (valid){{
atomicOr(mask_fwd + output_index, filter_mask_fwd); atomicOr(mask_fwd + output_index * mask_int_count + filter_pointer_offset, filter_mask_fwd);
indice_pairs_fwd_filter[output_index] = input_index; indice_pairs_fwd_filter[output_index] = input_index;
}} }}
}} }}
...@@ -812,11 +820,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -812,11 +820,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("RS", "int") code.arg("RS", "int")
code.arg("is_train", "bool") code.arg("is_train", "bool")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
uint32_t filter_mask_out = (1u << (filter_offset)); uint32_t filter_mask_out = (1u << (filter_offset % 32));
uint32_t filter_mask_in = (1u << (RS - 1 - filter_offset)); uint32_t filter_mask_out_offset = filter_offset / 32;
uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32));
uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32;
// uint32_t filter_mask_center = (1u << (RS / 2)); // uint32_t filter_mask_center = (1u << (RS / 2));
loc_iter.set_filter_offset(filter_offset); loc_iter.set_filter_offset(filter_offset);
...@@ -843,8 +854,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -843,8 +854,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto table_offset = table.lookup_offset(offset); // performance bound auto table_offset = table.lookup_offset(offset); // performance bound
if (table_offset != -1){{ if (table_offset != -1){{
auto input_index = table.value_ptr()[table_offset]; // we find a input indice idx. auto input_index = table.value_ptr()[table_offset]; // we find a input indice idx.
atomicOr(mask + output_index, filter_mask_out); atomicOr(mask + output_index * mask_int_count + filter_mask_out_offset, filter_mask_out);
atomicOr(mask + input_index, filter_mask_in); atomicOr(mask + input_index * mask_int_count + filter_mask_in_offset, filter_mask_in);
// for this output, we set correct input idx. // for this output, we set correct input idx.
indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index; indice_pairs[filter_offset_mul_indices_pair_size + output_index] = input_index;
if (is_train){{ if (is_train){{
...@@ -1244,6 +1255,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1244,6 +1255,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>") f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
...@@ -1320,11 +1332,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1320,11 +1332,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(), indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(), indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(), mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
num_act_in, indice_pairs_fwd.dim(1)); num_act_in, indice_pairs_fwd.dim(1),
mask_int_count);
launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output, launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output,
indice_pairs_bwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
mask_bwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
num_act_in, kv); num_act_in, kv,
mask_int_count);
if (mask_fwd.dim(0) == 2){{ if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx); mask_fwd[1].copy_(mask_fwd[0], ctx);
}} }}
...@@ -1336,7 +1350,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1336,7 +1350,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(), indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(), indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(), mask_fwd.data_ptr<uint32_t>(),
num_act_in, indice_pairs_fwd.dim(1)); num_act_in, indice_pairs_fwd.dim(1),
mask_int_count);
if (mask_fwd.dim(0) == 2){{ if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx); mask_fwd[1].copy_(mask_fwd[0], ctx);
}} }}
...@@ -1489,6 +1504,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1489,6 +1504,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("is_train", "bool", "true") code.arg("is_train", "bool", "true")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int num_act_in_real = indices.dim(0); int num_act_in_real = indices.dim(0);
...@@ -1496,7 +1512,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1496,7 +1512,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto ctx = tv::Context(); auto ctx = tv::Context();
ctx.set_cuda_stream(custream); ctx.set_cuda_stream(custream);
if (!indice_pair_mask.empty()){{ if (!indice_pair_mask.empty()){{
TV_ASSERT_INVALID_ARG(ksize.op<tv::arrayops::prod>() <= 32, "for now only support 32bit mask"); // TV_ASSERT_INVALID_ARG(ksize.op<tv::arrayops::prod>() <= 32, "for now only support 32bit mask");
}} }}
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
...@@ -1557,11 +1573,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1557,11 +1573,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}}else{{ }}else{{
// indice_pair_mask: [1, num_act_in] // indice_pair_mask: [1, num_act_in]
tv::cuda::Launch lanucher_fill(num_act_in_real, custream); tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
if (mask_int_count == 1)
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0)); lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
else
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error"); TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash, launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(), indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv, is_train); indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv, is_train, mask_int_count);
}} }}
}}else{{ }}else{{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error"); TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
...@@ -1576,6 +1596,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1576,6 +1596,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
return code.ret("int") return code.ret("int")
@pccm.cuda.cuda_global_function
def init_subm_multiple_mask_int_kernel(self):
    """Generate the CUDA kernel that initialises a multi-word mask buffer.

    The generated kernel is templated on the element type ``T`` and takes:

    * ``ptr``: destination buffer, written as ``length`` consecutive groups
      of ``mask_int_count`` elements.
    * ``set_bit``: index of the single bit to set in every group. The
      visible call site passes ``kv / 2`` here.
    * ``length``: number of groups to initialise.
    * ``mask_int_count``: number of elements per group.

    For every group, the element at index ``set_bit / 32`` receives
    ``1 << (set_bit % 32)`` and every other element is written as zero.

    Returns:
        The ``pccm.FunctionCode`` object describing the kernel.
    """
    code = pccm.FunctionCode()
    code.targ("T")
    code.arg("ptr", "T*")
    code.arg("set_bit", "int")
    code.arg("length", "int")
    code.arg("mask_int_count", "int")
    code.raw(f"""
    // Element (within a group) that holds the requested bit, and the
    // position of that bit inside the element.
    int bit_offset = set_bit / 32;
    int bit_residue = set_bit % 32;
    // Shift an unsigned literal: a signed `1 << 31` overflows int.
    T bit_value = static_cast<T>(1u << bit_residue);
    for (int offset : tv::KernelLoopX<int>(length)){{
      for (int i = 0; i < mask_int_count; ++i){{
        ptr[offset * mask_int_count + i] = (i == bit_offset) ? bit_value : T(0);
      }}
    }}
    """)
    return code
class SparseConvIndicesCPU(pccm.ParameterizedClass): class SparseConvIndicesCPU(pccm.ParameterizedClass):
......
...@@ -111,8 +111,8 @@ class SparseConvolution(SparseModule): ...@@ -111,8 +111,8 @@ class SparseConvolution(SparseModule):
algo = ConvAlgo.MaskImplicitGemm algo = ConvAlgo.MaskImplicitGemm
else: else:
algo = ConvAlgo.Native algo = ConvAlgo.Native
if kv > 32: # if kv > 32:
assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now" # assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now"
if CPU_ONLY_BUILD: if CPU_ONLY_BUILD:
assert algo == ConvAlgo.Native, "cpu only build only support native algorithm" assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
self.algo = algo self.algo = algo
...@@ -481,6 +481,7 @@ class SparseConvolution(SparseModule): ...@@ -481,6 +481,7 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks masks = datas.masks
mask_int_count = datas.mask_int_count
assert self.subm, "only support reuse subm indices" assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape, self._check_subm_reuse_valid(input, spatial_shape,
datas) datas)
...@@ -523,6 +524,7 @@ class SparseConvolution(SparseModule): ...@@ -523,6 +524,7 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = res[6] mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7] mask_argsort_bwd_splits = res[7]
masks = res[8] masks = res[8]
mask_int_count = res[9]
if self.indice_key is not None: if self.indice_key is not None:
indice_data = ImplicitGemmIndiceData( indice_data = ImplicitGemmIndiceData(
outids, outids,
...@@ -541,7 +543,8 @@ class SparseConvolution(SparseModule): ...@@ -541,7 +543,8 @@ class SparseConvolution(SparseModule):
ksize=self.kernel_size, ksize=self.kernel_size,
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
dilation=self.dilation) dilation=self.dilation,
mask_int_count=mask_int_count)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor." msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data indice_dict[self.indice_key] = indice_data
...@@ -553,7 +556,7 @@ class SparseConvolution(SparseModule): ...@@ -553,7 +556,7 @@ class SparseConvolution(SparseModule):
features, self.weight, pair_fwd, pair_bwd, features, self.weight, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm, num_activate_out, masks, mask_int_count, self.training, self.subm,
input._timer, self.fp32_accum, input._timer, self.fp32_accum,
bias_for_infer, bias_for_infer,
self.act_alpha, self.act_alpha,
......
...@@ -89,7 +89,8 @@ class ImplicitGemmIndiceData(object): ...@@ -89,7 +89,8 @@ class ImplicitGemmIndiceData(object):
out_spatial_shape, is_subm: bool, algo: ConvAlgo, out_spatial_shape, is_subm: bool, algo: ConvAlgo,
ksize: List[int], stride: List[int], dilation: List[int], padding: List[int], ksize: List[int], stride: List[int], dilation: List[int], padding: List[int],
in_voxel_num: Optional[Any] = None, in_voxel_num: Optional[Any] = None,
out_voxel_num: Optional[Any] = None): out_voxel_num: Optional[Any] = None,
mask_int_count: int=1):
self.out_indices = out_indices self.out_indices = out_indices
self.indices = indices self.indices = indices
self.pair_fwd = pair_fwd self.pair_fwd = pair_fwd
...@@ -110,6 +111,7 @@ class ImplicitGemmIndiceData(object): ...@@ -110,6 +111,7 @@ class ImplicitGemmIndiceData(object):
# in/out voxel_num is only used in tensorrt conversion. # in/out voxel_num is only used in tensorrt conversion.
self.in_voxel_num = in_voxel_num self.in_voxel_num = in_voxel_num
self.out_voxel_num = out_voxel_num self.out_voxel_num = out_voxel_num
self.mask_int_count = mask_int_count
def scatter_nd(indices, updates, shape): def scatter_nd(indices, updates, shape):
......
...@@ -198,6 +198,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -198,6 +198,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
num_activate_out: int, num_activate_out: int,
masks: List[np.ndarray], masks: List[np.ndarray],
mask_int_count: int,
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
...@@ -209,7 +210,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -209,7 +210,7 @@ class SparseImplicitGemmFunction(Function):
try: try:
out, mask_out, mask_width = ops.implicit_gemm( out, mask_out, mask_width = ops.implicit_gemm(
features, filters, pair_fwd, pair_mask_fwd_splits, features, filters, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits, num_activate_out, masks, is_train, mask_argsort_fwd_splits, num_activate_out, masks, mask_int_count, is_train,
is_subm, timer, fp32_accum, bias, act_alpha, act_beta, is_subm, timer, fp32_accum, bias, act_alpha, act_beta,
act_type) act_type)
except Exception as e: except Exception as e:
...@@ -235,6 +236,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -235,6 +236,7 @@ class SparseImplicitGemmFunction(Function):
ctx.masks = masks ctx.masks = masks
ctx.is_subm = is_subm ctx.is_subm = is_subm
ctx.fp32_accum = fp32_accum ctx.fp32_accum = fp32_accum
ctx.mask_int_count = mask_int_count
return out return out
@staticmethod @staticmethod
...@@ -253,6 +255,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -253,6 +255,7 @@ class SparseImplicitGemmFunction(Function):
is_subm = ctx.is_subm is_subm = ctx.is_subm
timer = ctx.timer timer = ctx.timer
fp32_accum = ctx.fp32_accum fp32_accum = ctx.fp32_accum
mask_int_count = ctx.mask_int_count
try: try:
input_bp, filters_bp = ops.implicit_gemm_backward( input_bp, filters_bp = ops.implicit_gemm_backward(
...@@ -267,6 +270,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -267,6 +270,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits, mask_argsort_bwd_splits,
mask_output_fwd=mask_out, mask_output_fwd=mask_out,
masks=masks, masks=masks,
mask_int_count=mask_int_count,
mask_width=mask_width, mask_width=mask_width,
is_subm=is_subm, is_subm=is_subm,
timer=timer, timer=timer,
...@@ -282,7 +286,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -282,7 +286,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits, masks)) mask_argsort_bwd_splits, masks))
raise e raise e
None_9 = [None] * 16 None_9 = [None] * 17
return (input_bp, filters_bp, *None_9) return (input_bp, filters_bp, *None_9)
......
...@@ -365,7 +365,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -365,7 +365,7 @@ def get_indice_pairs_implicit_gemm(
timer_cpp = tv.CUDAKernelTimer(False) timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None: if timer._timer is not None:
timer_cpp = timer._timer timer_cpp = timer._timer
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm( mask_tensor, num_act_out, mask_int_count = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc, thalloc,
torch_tensor_to_tv(indices), torch_tensor_to_tv(indices),
batch_size, batch_size,
...@@ -408,7 +408,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -408,7 +408,7 @@ def get_indice_pairs_implicit_gemm(
assert pair.shape[0] == 2 assert pair.shape[0] == 2
pair_bwd = pair[1] pair_bwd = pair[1]
return (out_inds, indice_num_per_loc, pair[0], pair_bwd, return (out_inds, indice_num_per_loc, pair[0], pair_bwd,
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
else: else:
pair_bwd = thalloc.allocated.get(AllocKeys.PairBwd, torch.Tensor()) pair_bwd = thalloc.allocated.get(AllocKeys.PairBwd, torch.Tensor())
pair_fwd = thalloc.allocated[AllocKeys.PairFwd] pair_fwd = thalloc.allocated[AllocKeys.PairFwd]
...@@ -437,12 +437,16 @@ def get_indice_pairs_implicit_gemm( ...@@ -437,12 +437,16 @@ def get_indice_pairs_implicit_gemm(
] ]
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd, return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks) mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks, mask_int_count)
assert indices.is_cuda, "implicit gemm only support cuda" assert indices.is_cuda, "implicit gemm only support cuda"
ndim = indices.shape[1] - 1 ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1) kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
# TODO in future we will support up to 128 kernel volume. # TODO in future we will support up to 128 kernel volume.
assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm" # assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm"
mask_int_count = (kv + 31) // 32
if 1 < mask_int_count < 4:
mask_int_count = 4
assert mask_int_count in [1, 4]
if not subm: if not subm:
if transpose: if transpose:
...@@ -489,6 +493,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -489,6 +493,7 @@ def get_indice_pairs_implicit_gemm(
pair_tv = torch_tensor_to_tv(pair) pair_tv = torch_tensor_to_tv(pair)
indice_num_per_loc_tv = torch_tensor_to_tv(indice_num_per_loc) indice_num_per_loc_tv = torch_tensor_to_tv(indice_num_per_loc)
if is_mask_split: if is_mask_split:
assert mask_int_count == 1, "Not Implemented"
kv_div_2 = kv // 2 kv_div_2 = kv // 2
remain = kv - kv_div_2 remain = kv - kv_div_2
mask_np_1 = np.array([1], dtype=np.uint64) mask_np_1 = np.array([1], dtype=np.uint64)
...@@ -506,7 +511,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -506,7 +511,7 @@ def get_indice_pairs_implicit_gemm(
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k, hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device) indices.device)
pair_mask = torch.empty((mask_split_count, indices.shape[0]), pair_mask = torch.empty((mask_split_count, indices.shape[0] * mask_int_count),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
...@@ -526,7 +531,8 @@ def get_indice_pairs_implicit_gemm( ...@@ -526,7 +531,8 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation, dilation=dilation,
indice_pair_mask=pair_mask_tv, indice_pair_mask=pair_mask_tv,
backward=is_train, backward=is_train,
stream_int=stream) stream_int=stream,
mask_int_count=mask_int_count)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print("SUBM0", time.time() - t) # print("SUBM0", time.time() - t)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
...@@ -543,9 +549,10 @@ def get_indice_pairs_implicit_gemm( ...@@ -543,9 +549,10 @@ def get_indice_pairs_implicit_gemm(
# so I use this stupid hack to use torch allocator without touch # so I use this stupid hack to use torch allocator without touch
# pytorch binary (c++). # pytorch binary (c++).
# f**k thrust # f**k thrust
SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j], SpconvOps.sort_1d_by_key_allocator_mask_auto(pair_mask_tv[j],
alloc.alloc, alloc.alloc,
mask_argsort_tv[j], stream) mask_argsort_tv[j], stream,
mask_int_count)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)] pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
mask_argsort_in_splits = [ mask_argsort_in_splits = [
...@@ -553,10 +560,10 @@ def get_indice_pairs_implicit_gemm( ...@@ -553,10 +560,10 @@ def get_indice_pairs_implicit_gemm(
] ]
if is_train: if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1], return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
else: else:
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(), return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
else: else:
max_num_act = SpconvOps.get_handcrafted_max_act_out( max_num_act = SpconvOps.get_handcrafted_max_act_out(
indices.shape[0], ksize, stride, padding, dilation) indices.shape[0], ksize, stride, padding, dilation)
...@@ -648,7 +655,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -648,7 +655,7 @@ def get_indice_pairs_implicit_gemm(
-1, -1,
dtype=indices.dtype, dtype=indices.dtype,
device=indices.device) device=indices.device)
pair_mask_fwd = torch.zeros((mask_split_count, num_act_out), pair_mask_fwd = torch.zeros((mask_split_count, num_act_out * mask_int_count),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd) pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
...@@ -658,7 +665,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -658,7 +665,7 @@ def get_indice_pairs_implicit_gemm(
pair_mask_bwd_tv = tv.Tensor() pair_mask_bwd_tv = tv.Tensor()
if is_train: if is_train:
pair_mask_bwd = torch.zeros( pair_mask_bwd = torch.zeros(
(mask_split_count, indices.shape[0]), (mask_split_count, indices.shape[0] * mask_int_count),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd, pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
...@@ -706,7 +713,8 @@ def get_indice_pairs_implicit_gemm( ...@@ -706,7 +713,8 @@ def get_indice_pairs_implicit_gemm(
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
transposed=transpose, transposed=transpose,
stream_int=stream) stream_int=stream,
mask_int_count=mask_int_count)
mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]), mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
...@@ -758,25 +766,26 @@ def get_indice_pairs_implicit_gemm( ...@@ -758,25 +766,26 @@ def get_indice_pairs_implicit_gemm(
else: else:
# if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): # if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
if not is_train: if not is_train:
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0], SpconvOps.sort_1d_by_key_allocator_mask_auto(pair_mask_fwd_tv[0],
alloc.alloc, alloc.alloc,
mask_argsort_fwd_tv[0], mask_argsort_fwd_tv[0],
stream) stream,
mask_int_count)
else: else:
if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator_mask_auto(
pair_mask_bwd_tv[0], alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream) mask_argsort_bwd_tv[0], stream, mask_int_count)
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator_mask_auto(
pair_mask_fwd_tv[0], alloc.alloc, pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream) mask_argsort_fwd_tv[0], stream, mask_int_count)
else: else:
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator_mask_auto(
pair_mask_fwd_tv[0], alloc.alloc, pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream) mask_argsort_fwd_tv[0], stream, mask_int_count)
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator_mask_auto(
pair_mask_bwd_tv[0], alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream) mask_argsort_bwd_tv[0], stream, mask_int_count)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
if not is_train: if not is_train:
...@@ -799,7 +808,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -799,7 +808,7 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd, return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks) mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks, mask_int_count)
def indice_conv(features: torch.Tensor, def indice_conv(features: torch.Tensor,
...@@ -1448,6 +1457,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1448,6 +1457,7 @@ def implicit_gemm(features: torch.Tensor,
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
num_activate_out: int, num_activate_out: int,
masks: List[np.ndarray], masks: List[np.ndarray],
mask_int_count: int,
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
...@@ -1491,7 +1501,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1491,7 +1501,7 @@ def implicit_gemm(features: torch.Tensor,
mask_width, tune_res_cpp = ConvGemmOps.implicit_gemm( mask_width, tune_res_cpp = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv, pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, num_activate_out, mask_tv, mask_int_count, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type, timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type,
use_tf32=constants.SPCONV_ALLOW_TF32) use_tf32=constants.SPCONV_ALLOW_TF32)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
...@@ -1557,13 +1567,14 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1557,13 +1567,14 @@ def implicit_gemm(features: torch.Tensor,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
stream=stream, stream=stream,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32) use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
mask_width = tune_res.algo_desp.tile_shape[0] mask_width = tune_res.algo_desp.tile_shape[0]
if is_train: if is_train:
mask_output_fwd = torch.empty( mask_output_fwd = torch.empty(
[num_split, [num_split,
codeops.div_up(num_activate_out, mask_width)], codeops.div_up(num_activate_out, mask_width) * mask_int_count],
dtype=torch.int32, dtype=torch.int32,
device=features.device) device=features.device)
# pytorch don't support uint32. # pytorch don't support uint32.
...@@ -1611,7 +1622,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1611,7 +1622,8 @@ def implicit_gemm(features: torch.Tensor,
bias=bias_tv, bias=bias_tv,
act_type=act_type, act_type=act_type,
act_alpha=act_alpha, act_alpha=act_alpha,
act_beta=act_beta) act_beta=act_beta,
mask_int_count=mask_int_count)
# INT8_TEST = True # INT8_TEST = True
# if INT8_TEST: # if INT8_TEST:
# if features.shape[1] % 32 != 0: # if features.shape[1] % 32 != 0:
...@@ -1710,6 +1722,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1710,6 +1722,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
mask_output_fwd: Optional[torch.Tensor], mask_output_fwd: Optional[torch.Tensor],
masks: List[np.ndarray], masks: List[np.ndarray],
mask_int_count: int,
mask_width: int, mask_width: int,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
...@@ -1769,7 +1782,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1769,7 +1782,7 @@ def implicit_gemm_backward(features: torch.Tensor,
alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv, pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv, mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream, mask_output_fwd_tv, mask_tv, mask_int_count, arch, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum, timer_cpp, auto_fp32_accum, fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32) use_tf32=constants.SPCONV_ALLOW_TF32)
din = alloc.allocated[AllocKeys.DIn] din = alloc.allocated[AllocKeys.DIn]
...@@ -1848,7 +1861,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1848,7 +1861,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
stream=stream, stream=stream,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32) use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
if wgrad_tune_res is None: if wgrad_tune_res is None:
wgrad_tune_res, _ = CONV.tune_and_cache( wgrad_tune_res, _ = CONV.tune_and_cache(
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
...@@ -1867,7 +1881,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1867,7 +1881,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_output=tv.Tensor(), mask_output=tv.Tensor(),
mask_width=mask_width, mask_width=mask_width,
stream=stream, stream=stream,
use_tf32=constants.SPCONV_ALLOW_TF32) use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp, workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
wgrad_tune_res.splitk, wgrad_tune_res.splitk,
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
...@@ -1904,7 +1919,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1904,7 +1919,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[j].item(), mask_filter=masks[j].item(),
mask_width=-1, mask_width=-1,
beta=beta, beta=beta,
stream=stream) stream=stream,
mask_int_count=mask_int_count)
# for backward weight, beta = 0 because each split # for backward weight, beta = 0 because each split
# handle different kernel locations. # handle different kernel locations.
# TODO remove D iterator in backward weight kernel # TODO remove D iterator in backward weight kernel
...@@ -1923,7 +1939,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1923,7 +1939,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_width=mask_width, mask_width=mask_width,
beta=0, beta=0,
workspace=workspace_tv, workspace=workspace_tv,
stream=stream) stream=stream,
mask_int_count=mask_int_count)
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
......
...@@ -291,6 +291,335 @@ class Net2(nn.Module): ...@@ -291,6 +291,335 @@ class Net2(nn.Module):
return self.net(x) return self.net(x)
class Net_kv75(nn.Module):
    """Benchmark network built from large-kernel ``(3, 5, 5)`` sparse convs.

    The stack is a sequence of submanifold-conv pairs separated by 2x2 sparse
    max-pooling, widening the channels 3 -> 64 -> 96 -> ... -> 256.  A
    ``(3, 5, 5)`` kernel has 3 * 5 * 5 = 75 taps, which is presumably what
    the ``kv75`` suffix refers to.
    """

    def __init__(self, shape, algo):
        """Create the layer stack.

        Args:
            shape: spatial shape; stored on the instance and handed to
                ``spconv.SparseConvTensor`` in :meth:`forward`.
            algo: convolution algorithm forwarded to every conv and pooling
                layer.
        """
        super().__init__()
        kernel = (3, 5, 5)
        # Pooling follows the conv algorithm.  It has its own name so that it
        # can be switched independently (e.g. to ConvAlgo.Native) when
        # experimenting.
        pool_algo = algo

        def subm(c_in, c_out, key):
            # Bias-free submanifold conv; layers sharing ``key`` are given the
            # same indice_key.
            return spconv.SubMConv3d(c_in,
                                     c_out,
                                     kernel,
                                     bias=False,
                                     indice_key=key,
                                     algo=algo)

        def pool(**extra):
            # 2x2 sparse max-pool; ``extra`` carries the optional indice_key.
            return spconv.SparseMaxPool3d(2, 2, algo=pool_algo, **extra)

        # The list is evaluated top to bottom, so modules are created (and
        # their parameters initialised) in exactly this order.
        layers = [
            subm(3, 64, "c0"),
            subm(64, 64, "c0"),
            pool(),
            subm(64, 96, "c1"),
            subm(96, 96, "c1"),
            pool(),
            subm(96, 128, "c2"),
            subm(128, 128, "c2"),
            pool(),
            subm(128, 160, "c3"),
            subm(160, 160, "c3"),
            pool(),
            subm(160, 192, "c4"),
            subm(192, 192, "c4"),
            pool(indice_key="m4"),
            subm(192, 224, "c5"),
            subm(224, 224, "c5"),
            pool(indice_key="m5"),
            # The only regular (non-submanifold) conv in the stack: padded
            # and created without an indice_key.
            spconv.SparseConv3d(224,
                                256,
                                kernel,
                                padding=(1, 2, 2),
                                bias=False,
                                algo=algo),
            subm(256, 256, "c6"),
        ]
        self.net = spconv.SparseSequential(*layers)
        # A pre-allocated dense grid could be supplied for faster indice
        # generation; this benchmark does not allocate one.
        self.shape = shape

    def forward(self, features, coors, batch_size, enable_timer: bool = False):
        """Wrap the inputs in a ``SparseConvTensor`` and run the stack.

        Args:
            features: feature tensor passed to ``SparseConvTensor``.
            coors: indices passed to ``SparseConvTensor``.
            batch_size: batch size passed to ``SparseConvTensor``.
            enable_timer: forwarded as ``SparseConvTensor(enable_timer=...)``.

        Returns:
            Whatever ``self.net`` returns for the constructed sparse tensor.
        """
        sparse_input = spconv.SparseConvTensor(features,
                                               coors,
                                               self.shape,
                                               batch_size,
                                               enable_timer=enable_timer)
        return self.net(sparse_input)
class Net_kv125(nn.Module):
    """Benchmark network built from large-kernel ``5x5x5`` submanifold convs.

    Pairs of bias-free ``SubMConv3d`` layers are separated by 2x2 sparse
    max-pooling while the channel count grows 3 -> 64 -> 96 -> ... -> 256.
    A cubic kernel of size 5 has 5 * 5 * 5 = 125 taps, which is presumably
    what the ``kv125`` suffix refers to.
    """

    def __init__(self, shape, algo):
        """Create the layer stack.

        Args:
            shape: spatial shape; stored on the instance and handed to
                ``spconv.SparseConvTensor`` in :meth:`forward`.
            algo: convolution algorithm forwarded to every conv and pooling
                layer.
        """
        super().__init__()
        kernel = 5
        # Pooling follows the conv algorithm.  It has its own name so that it
        # can be switched independently (e.g. to ConvAlgo.Native) when
        # experimenting.
        pool_algo = algo

        def subm(c_in, c_out, key):
            # Bias-free submanifold conv; layers sharing ``key`` are given the
            # same indice_key.
            return spconv.SubMConv3d(c_in,
                                     c_out,
                                     kernel,
                                     bias=False,
                                     indice_key=key,
                                     algo=algo)

        def pool(**extra):
            # 2x2 sparse max-pool; ``extra`` carries the optional indice_key.
            return spconv.SparseMaxPool3d(2, 2, algo=pool_algo, **extra)

        # The list is evaluated top to bottom, so modules are created (and
        # their parameters initialised) in exactly this order.
        layers = [
            subm(3, 64, "c0"),
            subm(64, 64, "c0"),
            pool(),
            subm(64, 96, "c1"),
            subm(96, 96, "c1"),
            pool(),
            subm(96, 128, "c2"),
            subm(128, 128, "c2"),
            pool(),
            subm(128, 160, "c3"),
            subm(160, 160, "c3"),
            pool(),
            subm(160, 192, "c4"),
            subm(192, 192, "c4"),
            pool(indice_key="m4"),
            subm(192, 224, "c5"),
            subm(224, 224, "c5"),
            pool(indice_key="m5"),
            subm(224, 256, "c6"),
            subm(256, 256, "c6"),
        ]
        self.net = spconv.SparseSequential(*layers)
        # A pre-allocated dense grid could be supplied for faster indice
        # generation; this benchmark does not allocate one.
        self.shape = shape

    def forward(self, features, coors, batch_size, enable_timer: bool = False):
        """Wrap the inputs in a ``SparseConvTensor`` and run the stack.

        Args:
            features: feature tensor passed to ``SparseConvTensor``.
            coors: indices passed to ``SparseConvTensor``.
            batch_size: batch size passed to ``SparseConvTensor``.
            enable_timer: forwarded as ``SparseConvTensor(enable_timer=...)``.

        Returns:
            Whatever ``self.net`` returns for the constructed sparse tensor.
        """
        sparse_input = spconv.SparseConvTensor(features,
                                               coors,
                                               self.shape,
                                               batch_size,
                                               enable_timer=enable_timer)
        return self.net(sparse_input)
class NetSm(nn.Module): class NetSm(nn.Module):
def __init__(self, shape, algo): def __init__(self, shape, algo):
...@@ -395,9 +724,9 @@ def main(): ...@@ -395,9 +724,9 @@ def main():
np.random.seed(50051) np.random.seed(50051)
torch.manual_seed(50051) torch.manual_seed(50051)
# voxels, coors, spatial_shape = waymo_data(num_features=3) # voxels, coors, spatial_shape = waymo_data(num_features=3)
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f: # with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f) # (voxels, coors, spatial_shape) = pickle.load(f)
# voxels, coors, spatial_shape = waymo_data_large() voxels, coors, spatial_shape = waymo_data_large()
# breakpoint() # breakpoint()
print(spatial_shape) print(spatial_shape)
...@@ -474,17 +803,18 @@ def main(): ...@@ -474,17 +803,18 @@ def main():
# state.pop("net.2.max_num_voxels_during_training") # state.pop("net.2.max_num_voxels_during_training")
# net.load_state_dict(state) # net.load_state_dict(state)
# breakpoint() # breakpoint()
# net = net.train()
print("spconv time", np.mean(times[10:])) print("spconv time", np.mean(times[10:]))
# times = [] # times = []
# for i in range(10): # for i in range(10):
# out = net(voxels_th, coors_th, 1) # out = net(voxels_th, coors_th, 1)
# print("------------") # print("------------")
# # torch.cuda.synchronize() # torch.cuda.synchronize()
# # t = time.time() # t = time.time()
# out.features.backward(dout_t) # out.features.backward(dout_t)
# # torch.cuda.synchronize() # torch.cuda.synchronize()
# # times.append(time.time() - t) # times.append(time.time() - t)
# # # print((net.grid == -1).float().sum(), net.grid.numel()) # # # print((net.grid == -1).float().sum(), net.grid.numel())
# # # print("spconv time", time.time() - t) # # # print("spconv time", time.time() - t)
......
...@@ -135,6 +135,7 @@ class SparseConvTester: ...@@ -135,6 +135,7 @@ class SparseConvTester:
self.mask_argsort_fwd_splits = res[6] self.mask_argsort_fwd_splits = res[6]
self.mask_argsort_bwd_splits = res[7] self.mask_argsort_bwd_splits = res[7]
self.masks = res[8] self.masks = res[8]
self.mask_int_count = res[9]
self.out_inds_scalar = Fsp._indice_to_scalar(self.out_inds.long(), [bs, *out_shape]) self.out_inds_scalar = Fsp._indice_to_scalar(self.out_inds.long(), [bs, *out_shape])
...@@ -293,12 +294,12 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -293,12 +294,12 @@ def _test_impgemm_conv_cuda(subm: bool):
multiple_base = 16 multiple_base = 16
if subm: if subm:
ksizes = [3] ksizes = [3, (3, 3, 5), (3, 5, 5), 5]
strides = [1] strides = [1]
paddings = [0] paddings = [0]
dilations = [1] dilations = [1]
else: else:
ksizes = [2, 3] ksizes = [2, 3, (3, 3, 4), 4, (4, 5, 5), 5]
strides = [1, 2, 3] strides = [1, 2, 3]
paddings = [0, 1] paddings = [0, 1]
dilations = [1, 2] dilations = [1, 2]
...@@ -354,8 +355,9 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -354,8 +355,9 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_width = desp.tile_shape[0] mask_width = desp.tile_shape[0]
# if mask_width != 32: # if mask_width != 32:
# continue # continue
if mask_width not in mask_width_to_mask_out_fwd: if mask_width not in mask_width_to_mask_out_fwd:
mask_width_to_mask_out_fwd[mask_width] = torch.zeros([2, div_up(tester.out_inds.shape[0], mask_width)], mask_width_to_mask_out_fwd[mask_width] = torch.zeros([2, tester.mask_int_count * div_up(tester.out_inds.shape[0], mask_width)],
dtype=torch.int32, dtype=torch.int32,
device=tester.device) device=tester.device)
mask_output_fwd = mask_width_to_mask_out_fwd[mask_width] mask_output_fwd = mask_width_to_mask_out_fwd[mask_width]
...@@ -413,6 +415,7 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -413,6 +415,7 @@ def _test_impgemm_conv_cuda(subm: bool):
force_nvrtc=force_nvrtc, force_nvrtc=force_nvrtc,
bias=bias_cur if is_fwd and bias_cur is not None else tv.Tensor(), bias=bias_cur if is_fwd and bias_cur is not None else tv.Tensor(),
act_type=act, act_type=act,
mask_int_count=tester.mask_int_count,
) )
else: else:
CONV.run_with_tuned_result( CONV.run_with_tuned_result(
...@@ -433,6 +436,7 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -433,6 +436,7 @@ def _test_impgemm_conv_cuda(subm: bool):
force_nvrtc=force_nvrtc, force_nvrtc=force_nvrtc,
bias=bias_cur if is_fwd else None, bias=bias_cur if is_fwd else None,
act_type=act, act_type=act,
mask_int_count=tester.mask_int_count
) )
else: else:
...@@ -491,6 +495,7 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -491,6 +495,7 @@ def _test_impgemm_conv_cuda(subm: bool):
force_nvrtc=force_nvrtc, force_nvrtc=force_nvrtc,
bias=bias if is_fwd and bias is not None else tv.Tensor(), bias=bias if is_fwd and bias is not None else tv.Tensor(),
act_type=act, act_type=act,
mask_int_count=tester.mask_int_count,
) )
else: else:
CONV.run_with_tuned_result( CONV.run_with_tuned_result(
...@@ -511,6 +516,7 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -511,6 +516,7 @@ def _test_impgemm_conv_cuda(subm: bool):
force_nvrtc=force_nvrtc, force_nvrtc=force_nvrtc,
bias=bias if is_fwd else None, bias=bias if is_fwd else None,
act_type=act, act_type=act,
mask_int_count=tester.mask_int_count,
) )
out_ref = tester.out_ref out_ref = tester.out_ref
...@@ -572,6 +578,7 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -572,6 +578,7 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_width=mask_width, mask_width=mask_width,
beta=beta, beta=beta,
verbose=False, verbose=False,
mask_int_count=tester.mask_int_count,
) )
else: else:
indice_pairs = tester.pair_fwd # inp -> out indice_pairs = tester.pair_fwd # inp -> out
...@@ -599,6 +606,7 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -599,6 +606,7 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_width=mask_width, mask_width=mask_width,
beta=beta, beta=beta,
verbose=False, verbose=False,
mask_int_count=tester.mask_int_count,
) )
dw_ref = tester.dw_ref dw_ref = tester.dw_ref
dw_my = weight_tv.cpu().numpy() dw_my = weight_tv.cpu().numpy()
...@@ -918,8 +926,8 @@ def test_all_algo_unit(): ...@@ -918,8 +926,8 @@ def test_all_algo_unit():
# for i in range(5): # for i in range(5):
_test_impgemm_conv_cuda(True) _test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda(False) _test_impgemm_conv_cuda(False)
_test_native_conv_cuda(True) # _test_native_conv_cuda(True)
_test_native_conv_cuda(False) # _test_native_conv_cuda(False)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment