Commit cd99e7a6 authored by yan.yan's avatar yan.yan
Browse files

fix #575 use a flag to enable large-kernel algo

parent f101f97e
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
## [2.3.5] - 2023-03-24 ## [2.3.5] - 2023-03-24
### Fixed ### Fixed
- pypi project reach size limit, so we need to assign a new version number. - use a flag to enable large kernel algo (need time to compile at runtime)
## [2.3.4] - 2023-03-23 ## [2.3.4] - 2023-03-23
### Added ### Added
......
...@@ -26,4 +26,6 @@ ...@@ -26,4 +26,6 @@
* spconv 2.x in Windows 10 is 1.5x~2x slower than Linux. use Linux if possible. * spconv 2.x in Windows 10 is 1.5x~2x slower than Linux. use Linux if possible.
* If you train with float32 and ampere or later GPUs, you can set ```spconv.constants.SPCONV_ALLOW_TF32``` to enable faster fp32 training. * If you train with float32 and ampere or later GPUs, you can set ```spconv.constants.SPCONV_ALLOW_TF32``` to enable faster fp32 training.
See [benchmark](BENCHMARK.md) for more performance details of different algorithms. See [benchmark](BENCHMARK.md) for more performance details of different algorithms.
* Different CUDA version of spconv may have different performance. Use newest cuda version if possible. For example, spconv-cu117 is faster than spconv-cu114, spconv-cu114 is faster than spconv-cu111. * Different CUDA version of spconv may have different performance. Use newest cuda version if possible. For example, spconv-cu117 is faster than spconv-cu114, spconv-cu114 is faster than spconv-cu111.
\ No newline at end of file * if your kernel size volume larger than 32, spconv will use a slower (and more inaccurate in fp16) algorithm. to use a faster algo for large kernel size (need time to compile at runtime), use ```large_kernel_fast_algo=True```
* use ```SparseGlobalMaxPool``` instead of use large kernel size when you need global pool.
\ No newline at end of file
...@@ -43,7 +43,9 @@ FILTER_HWIO = False ...@@ -43,7 +43,9 @@ FILTER_HWIO = False
_MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training" _MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training"
def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float, act_beta: float):
def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float,
act_beta: float):
if act_type == tv.gemm.Activation.None_: if act_type == tv.gemm.Activation.None_:
return x return x
elif act_type == tv.gemm.Activation.ReLU: elif act_type == tv.gemm.Activation.ReLU:
...@@ -55,7 +57,9 @@ def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float, ...@@ -55,7 +57,9 @@ def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float,
else: else:
raise NotImplementedError raise NotImplementedError
class SparseConvolutionBase: class SparseConvolutionBase:
def __init__(self, def __init__(self,
ndim: int, ndim: int,
in_channels: int, in_channels: int,
...@@ -76,7 +80,8 @@ class SparseConvolutionBase: ...@@ -76,7 +80,8 @@ class SparseConvolutionBase:
record_voxel_count: bool = False, record_voxel_count: bool = False,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_, act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0, act_alpha: float = 0,
act_beta: float = 0): act_beta: float = 0,
large_kernel_fast_algo: bool = False):
assert groups == 1, "don't support groups for now" assert groups == 1, "don't support groups for now"
self.ndim = ndim self.ndim = ndim
self.in_channels = in_channels self.in_channels = in_channels
...@@ -103,7 +108,10 @@ class SparseConvolutionBase: ...@@ -103,7 +108,10 @@ class SparseConvolutionBase:
self.indice_key = indice_key self.indice_key = indice_key
self.record_voxel_count = record_voxel_count self.record_voxel_count = record_voxel_count
if algo is None: if algo is None:
if kv <= 128 and not CPU_ONLY_BUILD: limit = 32
if large_kernel_fast_algo:
limit = 128
if kv <= limit and not CPU_ONLY_BUILD:
if kv < 8: if kv < 8:
algo = ConvAlgo.MaskImplicitGemm algo = ConvAlgo.MaskImplicitGemm
else: else:
...@@ -117,7 +125,7 @@ class SparseConvolutionBase: ...@@ -117,7 +125,7 @@ class SparseConvolutionBase:
self.algo = algo self.algo = algo
self.fp32_accum = fp32_accum self.fp32_accum = fp32_accum
# self.algo = ConvAlgo.Native # self.algo = ConvAlgo.Native
if self.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC: if self.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO: if FILTER_HWIO:
# RSCK # RSCK
...@@ -136,16 +144,23 @@ class SparseConvolutionBase: ...@@ -136,16 +144,23 @@ class SparseConvolutionBase:
self.zero_point = 0 self.zero_point = 0
if self.conv1x1: if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act" assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
def is_inverseable(self): def is_inverseable(self):
return self.indice_key is not None and not self.subm return self.indice_key is not None and not self.subm
def _conv_forward(self, training: bool, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None, def _conv_forward(self,
channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None, name: Optional[str] = None, training: bool,
sparse_unique_name: str = "", input: SparseConvTensor,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_, weight: torch.Tensor,
act_alpha: float = 0, bias: Optional[torch.Tensor],
act_beta: float = 0): add_input: Optional[SparseConvTensor] = None,
channel_scale: Optional[torch.Tensor] = None,
output_scale: Optional[float] = None,
name: Optional[str] = None,
sparse_unique_name: str = "",
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0):
# assert isinstance(input, SparseConvTensor) # assert isinstance(input, SparseConvTensor)
is_int8 = input.is_quantized and weight.is_quantized is_int8 = input.is_quantized and weight.is_quantized
if is_int8: if is_int8:
...@@ -165,9 +180,9 @@ class SparseConvolutionBase: ...@@ -165,9 +180,9 @@ class SparseConvolutionBase:
if add_input is not None: if add_input is not None:
output_add_scale = add_input.q_scale() output_add_scale = add_input.q_scale()
if training: if training:
msg = "act don't support backward, only used in inference" msg = "act don't support backward, only used in inference"
assert self.act_type == tv.gemm.Activation.None_, msg assert self.act_type == tv.gemm.Activation.None_, msg
if not self.subm: if not self.subm:
if self.transposed: if self.transposed:
out_spatial_shape = ops.get_deconv_output_size( out_spatial_shape = ops.get_deconv_output_size(
...@@ -252,7 +267,7 @@ class SparseConvolutionBase: ...@@ -252,7 +267,7 @@ class SparseConvolutionBase:
indice_pair_num = datas.indice_pair_num indice_pair_num = datas.indice_pair_num
out_spatial_shape = datas.spatial_shape out_spatial_shape = datas.spatial_shape
self._check_inverse_reuse_valid(input, spatial_shape, self._check_inverse_reuse_valid(input, spatial_shape,
datas) datas)
else: else:
if self.indice_key is not None and datas is not None: if self.indice_key is not None and datas is not None:
outids = datas.out_indices outids = datas.out_indices
...@@ -283,8 +298,8 @@ class SparseConvolutionBase: ...@@ -283,8 +298,8 @@ class SparseConvolutionBase:
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
interval = time.time() - t interval = time.time() - t
out_tensor.benchmark_record[ out_tensor.benchmark_record[name][
name]["indice_gen_time"].append(interval) "indice_gen_time"].append(interval)
indice_data = IndiceData(outids, indice_data = IndiceData(outids,
indices, indices,
...@@ -310,43 +325,21 @@ class SparseConvolutionBase: ...@@ -310,43 +325,21 @@ class SparseConvolutionBase:
indice_pairs_calc = indice_pairs.to(features.device) indice_pairs_calc = indice_pairs.to(features.device)
if self.subm: if self.subm:
out_features = Fsp.indice_subm_conv( out_features = Fsp.indice_subm_conv(
features, features, weight, indice_pairs_calc, indice_pair_num,
weight, outids.shape[0], algo, input._timer, bias_for_infer,
indice_pairs_calc, act_alpha, act_beta, act_type)
indice_pair_num,
outids.shape[0],
algo,
input._timer,
bias_for_infer,
act_alpha,
act_beta,
act_type)
else: else:
if self.inverse: if self.inverse:
out_features = Fsp.indice_inverse_conv( out_features = Fsp.indice_inverse_conv(
features, features, weight, indice_pairs_calc,
weight, indice_pair_num, outids.shape[0], algo,
indice_pairs_calc, input._timer, bias_for_infer, act_alpha, act_beta,
indice_pair_num,
outids.shape[0],
algo,
input._timer,
bias_for_infer,
act_alpha,
act_beta,
act_type) act_type)
else: else:
out_features = Fsp.indice_conv( out_features = Fsp.indice_conv(
features, features, weight, indice_pairs_calc,
weight, indice_pair_num, outids.shape[0], algo,
indice_pairs_calc, input._timer, bias_for_infer, act_type, act_beta,
indice_pair_num,
outids.shape[0],
algo,
input._timer,
bias_for_infer,
act_type,
act_beta,
act_type) act_type)
else: else:
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
...@@ -367,7 +360,7 @@ class SparseConvolutionBase: ...@@ -367,7 +360,7 @@ class SparseConvolutionBase:
# assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv" # assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
self._check_inverse_reuse_valid(input, spatial_shape, self._check_inverse_reuse_valid(input, spatial_shape,
datas) datas)
else: else:
if self.indice_key is not None and datas is not None: if self.indice_key is not None and datas is not None:
outids = datas.out_indices outids = datas.out_indices
...@@ -416,8 +409,8 @@ class SparseConvolutionBase: ...@@ -416,8 +409,8 @@ class SparseConvolutionBase:
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
interval = time.time() - t interval = time.time() - t
out_tensor.benchmark_record[ out_tensor.benchmark_record[name][
name]["indice_gen_time"].append(interval) "indice_gen_time"].append(interval)
outids = res[0] outids = res[0]
num_inds_per_loc = res[1] num_inds_per_loc = res[1]
pair_fwd = res[2] pair_fwd = res[2]
...@@ -465,28 +458,34 @@ class SparseConvolutionBase: ...@@ -465,28 +458,34 @@ class SparseConvolutionBase:
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, training, self.subm, num_activate_out, masks, training, self.subm,
input._timer, self.fp32_accum, input._timer, self.fp32_accum, bias_cur, act_alpha,
bias_cur, act_beta, act_type)
act_alpha,
act_beta,
act_type)
else: else:
output_dtype = None output_dtype = None
if output_scale is None: if output_scale is None:
output_dtype = weight.dtype output_dtype = weight.dtype
out_features, _, _ = ops.implicit_gemm( out_features, _, _ = ops.implicit_gemm(
features, weight_cur, pair_fwd, pair_mask_fwd_splits, features,
mask_argsort_fwd_splits, weight_cur,
num_activate_out, masks, training, self.subm, pair_fwd,
input._timer, self.fp32_accum, pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out,
masks,
training,
self.subm,
input._timer,
self.fp32_accum,
bias_cur, bias_cur,
act_alpha, act_alpha,
act_beta, act_beta,
act_type, act_type,
# TODO do we really need output scale to scale bias in kernel? # TODO do we really need output scale to scale bias in kernel?
1.0 if output_scale is None else output_scale, # output_scale 1.0 if output_scale is None else
channel_scale, # scale output_scale, # output_scale
output_add=add_input.features if add_input is not None else None, channel_scale, # scale
output_add=add_input.features
if add_input is not None else None,
output_add_scale=output_add_scale, output_add_scale=output_add_scale,
output_dtype=output_dtype) output_dtype=output_dtype)
...@@ -511,8 +510,10 @@ class SparseConvolutionBase: ...@@ -511,8 +510,10 @@ class SparseConvolutionBase:
out_tensor.spatial_shape = out_spatial_shape out_tensor.spatial_shape = out_spatial_shape
if add_input is not None and not is_int8: if add_input is not None and not is_int8:
# in int8, we apply add + act in kernel. # in int8, we apply add + act in kernel.
out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta)) out_tensor = out_tensor.replace_feature(
_apply_act(out_tensor.features + add_input.features,
self.act_type, self.act_alpha, self.act_beta))
return out_tensor return out_tensor
def _check_subm_reuse_valid(self, inp: SparseConvTensor, def _check_subm_reuse_valid(self, inp: SparseConvTensor,
...@@ -539,9 +540,9 @@ class SparseConvolutionBase: ...@@ -539,9 +540,9 @@ class SparseConvolutionBase:
) )
def _check_inverse_reuse_valid(self, inp: SparseConvTensor, def _check_inverse_reuse_valid(self, inp: SparseConvTensor,
spatial_shape: List[int], spatial_shape: List[int],
datas: Union[ImplicitGemmIndiceData, datas: Union[ImplicitGemmIndiceData,
IndiceData]): IndiceData]):
if self.kernel_size != datas.ksize: if self.kernel_size != datas.ksize:
raise ValueError( raise ValueError(
f"Inverse with same indice_key must have same kernel" f"Inverse with same indice_key must have same kernel"
...@@ -556,8 +557,8 @@ class SparseConvolutionBase: ...@@ -556,8 +557,8 @@ class SparseConvolutionBase:
raise ValueError( raise ValueError(
f"Inverse with same indice_key must have same num of indices" f"Inverse with same indice_key must have same num of indices"
f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}, " f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}, "
"please check Inverse Convolution in ." "please check Inverse Convolution in .")
)
class SparseConvolution(SparseConvolutionBase, SparseModule): class SparseConvolution(SparseConvolutionBase, SparseModule):
__constants__ = [ __constants__ = [
...@@ -586,11 +587,21 @@ class SparseConvolution(SparseConvolutionBase, SparseModule): ...@@ -586,11 +587,21 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
act_type: tv.gemm.Activation = tv.gemm.Activation.None_, act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0, act_alpha: float = 0,
act_beta: float = 0, act_beta: float = 0,
large_kernel_fast_algo: bool = False,
name=None, name=None,
device=None, device=None,
dtype=None): dtype=None):
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
SparseConvolutionBase.__init__(self, ndim, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, SparseConvolutionBase.__init__(
self,
ndim,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias=False, bias=False,
subm=subm, subm=subm,
output_padding=output_padding, output_padding=output_padding,
...@@ -602,14 +613,17 @@ class SparseConvolution(SparseConvolutionBase, SparseModule): ...@@ -602,14 +613,17 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
act_type=act_type, act_type=act_type,
act_alpha=act_alpha, act_alpha=act_alpha,
act_beta=act_beta) act_beta=act_beta,
large_kernel_fast_algo=large_kernel_fast_algo)
SparseModule.__init__(self, name=name) SparseModule.__init__(self, name=name)
if record_voxel_count and not self.subm and not self.inverse: if record_voxel_count and not self.subm and not self.inverse:
# we record maximum voxel num in both inference and training if # we record maximum voxel num in both inference and training if
# record_voxel_count flag setting. # record_voxel_count flag setting.
self.register_buffer(_MAX_NUM_VOXELS_DURING_TRAINING, self.register_buffer(
torch.zeros(1, dtype=torch.int32, device=device)) _MAX_NUM_VOXELS_DURING_TRAINING,
self.weight = Parameter(torch.zeros(*self.weight_shape, **factory_kwargs)) torch.zeros(1, dtype=torch.int32, device=device))
self.weight = Parameter(
torch.zeros(*self.weight_shape, **factory_kwargs))
if bias: if bias:
self.bias = Parameter(torch.zeros(out_channels, **factory_kwargs)) self.bias = Parameter(torch.zeros(out_channels, **factory_kwargs))
else: else:
...@@ -636,8 +650,7 @@ class SparseConvolution(SparseConvolutionBase, SparseModule): ...@@ -636,8 +650,7 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
error_msgs): error_msgs):
name = prefix + _MAX_NUM_VOXELS_DURING_TRAINING name = prefix + _MAX_NUM_VOXELS_DURING_TRAINING
if self.record_voxel_count and not self.subm and not self.inverse and name not in state_dict: if self.record_voxel_count and not self.subm and not self.inverse and name not in state_dict:
state_dict[name] = torch.zeros( state_dict[name] = torch.zeros(1, dtype=torch.int32)
1, dtype=torch.int32)
if not SAVED_WEIGHT_LAYOUT: if not SAVED_WEIGHT_LAYOUT:
return return
key = prefix + "weight" key = prefix + "weight"
...@@ -736,14 +749,23 @@ class SparseConvolution(SparseConvolutionBase, SparseModule): ...@@ -736,14 +749,23 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
bound = 1 / math.sqrt(fan_in) bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound) init.uniform_(self.bias, -bound, bound)
def forward(self,
def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None): input: SparseConvTensor,
return self._conv_forward(self.training, input, self.weight, self.bias, add_input, add_input: Optional[SparseConvTensor] = None):
name=self.name, sparse_unique_name=self._sparse_unique_name, act_type=self.act_type, return self._conv_forward(self.training,
act_alpha=self.act_alpha, act_beta=self.act_beta) input,
self.weight,
self.bias,
add_input,
name=self.name,
sparse_unique_name=self._sparse_unique_name,
act_type=self.act_type,
act_alpha=self.act_alpha,
act_beta=self.act_beta)
class SparseConv1d(SparseConvolution): class SparseConv1d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -757,6 +779,7 @@ class SparseConv1d(SparseConvolution): ...@@ -757,6 +779,7 @@ class SparseConv1d(SparseConvolution):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseConv1d, super(SparseConv1d,
self).__init__(1, self).__init__(1,
...@@ -772,10 +795,12 @@ class SparseConv1d(SparseConvolution): ...@@ -772,10 +795,12 @@ class SparseConv1d(SparseConvolution):
algo=algo, algo=algo,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name) name=name)
class SparseConv2d(SparseConvolution): class SparseConv2d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -789,6 +814,7 @@ class SparseConv2d(SparseConvolution): ...@@ -789,6 +814,7 @@ class SparseConv2d(SparseConvolution):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseConv2d, super(SparseConv2d,
self).__init__(2, self).__init__(2,
...@@ -803,11 +829,13 @@ class SparseConv2d(SparseConvolution): ...@@ -803,11 +829,13 @@ class SparseConv2d(SparseConvolution):
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
name=name) name=name)
class SparseConv3d(SparseConvolution): class SparseConv3d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -821,6 +849,7 @@ class SparseConv3d(SparseConvolution): ...@@ -821,6 +849,7 @@ class SparseConv3d(SparseConvolution):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseConv3d, super(SparseConv3d,
self).__init__(3, self).__init__(3,
...@@ -835,11 +864,13 @@ class SparseConv3d(SparseConvolution): ...@@ -835,11 +864,13 @@ class SparseConv3d(SparseConvolution):
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
name=name) name=name)
class SparseConv4d(SparseConvolution): class SparseConv4d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -853,6 +884,7 @@ class SparseConv4d(SparseConvolution): ...@@ -853,6 +884,7 @@ class SparseConv4d(SparseConvolution):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseConv4d, super(SparseConv4d,
self).__init__(4, self).__init__(4,
...@@ -867,11 +899,13 @@ class SparseConv4d(SparseConvolution): ...@@ -867,11 +899,13 @@ class SparseConv4d(SparseConvolution):
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
name=name) name=name)
class SparseConvTranspose1d(SparseConvolution): class SparseConvTranspose1d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -885,6 +919,7 @@ class SparseConvTranspose1d(SparseConvolution): ...@@ -885,6 +919,7 @@ class SparseConvTranspose1d(SparseConvolution):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseConvTranspose1d, super(SparseConvTranspose1d,
self).__init__(1, self).__init__(1,
...@@ -900,11 +935,13 @@ class SparseConvTranspose1d(SparseConvolution): ...@@ -900,11 +935,13 @@ class SparseConvTranspose1d(SparseConvolution):
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
name=name) name=name)
class SparseConvTranspose2d(SparseConvolution): class SparseConvTranspose2d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -918,6 +955,7 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -918,6 +955,7 @@ class SparseConvTranspose2d(SparseConvolution):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseConvTranspose2d, super(SparseConvTranspose2d,
self).__init__(2, self).__init__(2,
...@@ -933,11 +971,13 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -933,11 +971,13 @@ class SparseConvTranspose2d(SparseConvolution):
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
name=name) name=name)
class SparseConvTranspose3d(SparseConvolution): class SparseConvTranspose3d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -951,6 +991,7 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -951,6 +991,7 @@ class SparseConvTranspose3d(SparseConvolution):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseConvTranspose3d, super(SparseConvTranspose3d,
self).__init__(3, self).__init__(3,
...@@ -966,11 +1007,13 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -966,11 +1007,13 @@ class SparseConvTranspose3d(SparseConvolution):
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
name=name) name=name)
class SparseConvTranspose4d(SparseConvolution): class SparseConvTranspose4d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -984,6 +1027,7 @@ class SparseConvTranspose4d(SparseConvolution): ...@@ -984,6 +1027,7 @@ class SparseConvTranspose4d(SparseConvolution):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseConvTranspose4d, super(SparseConvTranspose4d,
self).__init__(4, self).__init__(4,
...@@ -999,11 +1043,13 @@ class SparseConvTranspose4d(SparseConvolution): ...@@ -999,11 +1043,13 @@ class SparseConvTranspose4d(SparseConvolution):
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
name=name) name=name)
class SparseInverseConv1d(SparseConvolution): class SparseInverseConv1d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -1012,20 +1058,24 @@ class SparseInverseConv1d(SparseConvolution): ...@@ -1012,20 +1058,24 @@ class SparseInverseConv1d(SparseConvolution):
bias=True, bias=True,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseInverseConv1d, self).__init__(1, super(SparseInverseConv1d,
in_channels, self).__init__(1,
out_channels, in_channels,
kernel_size, out_channels,
bias=bias, kernel_size,
inverse=True, bias=bias,
indice_key=indice_key, inverse=True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SparseInverseConv2d(SparseConvolution): class SparseInverseConv2d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -1034,20 +1084,24 @@ class SparseInverseConv2d(SparseConvolution): ...@@ -1034,20 +1084,24 @@ class SparseInverseConv2d(SparseConvolution):
bias=True, bias=True,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseInverseConv2d, self).__init__(2, super(SparseInverseConv2d,
in_channels, self).__init__(2,
out_channels, in_channels,
kernel_size, out_channels,
bias=bias, kernel_size,
inverse=True, bias=bias,
indice_key=indice_key, inverse=True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SparseInverseConv3d(SparseConvolution): class SparseInverseConv3d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -1056,20 +1110,24 @@ class SparseInverseConv3d(SparseConvolution): ...@@ -1056,20 +1110,24 @@ class SparseInverseConv3d(SparseConvolution):
bias=True, bias=True,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseInverseConv3d, self).__init__(3, super(SparseInverseConv3d,
in_channels, self).__init__(3,
out_channels, in_channels,
kernel_size, out_channels,
bias=bias, kernel_size,
inverse=True, bias=bias,
indice_key=indice_key, inverse=True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SparseInverseConv4d(SparseConvolution): class SparseInverseConv4d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -1078,20 +1136,24 @@ class SparseInverseConv4d(SparseConvolution): ...@@ -1078,20 +1136,24 @@ class SparseInverseConv4d(SparseConvolution):
bias=True, bias=True,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SparseInverseConv4d, self).__init__(4, super(SparseInverseConv4d,
in_channels, self).__init__(4,
out_channels, in_channels,
kernel_size, out_channels,
bias=bias, kernel_size,
inverse=True, bias=bias,
indice_key=indice_key, inverse=True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SubMConv1d(SparseConvolution): class SubMConv1d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -1104,24 +1166,28 @@ class SubMConv1d(SparseConvolution): ...@@ -1104,24 +1166,28 @@ class SubMConv1d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SubMConv1d, self).__init__(1, super(SubMConv1d,
in_channels, self).__init__(1,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
True, bias,
indice_key=indice_key, True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SubMConv2d(SparseConvolution): class SubMConv2d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -1134,24 +1200,28 @@ class SubMConv2d(SparseConvolution): ...@@ -1134,24 +1200,28 @@ class SubMConv2d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SubMConv2d, self).__init__(2, super(SubMConv2d,
in_channels, self).__init__(2,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
True, bias,
indice_key=indice_key, True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SubMConv3d(SparseConvolution): class SubMConv3d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -1164,24 +1234,28 @@ class SubMConv3d(SparseConvolution): ...@@ -1164,24 +1234,28 @@ class SubMConv3d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SubMConv3d, self).__init__(3, super(SubMConv3d,
in_channels, self).__init__(3,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
True, bias,
indice_key=indice_key, True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SubMConv4d(SparseConvolution): class SubMConv4d(SparseConvolution):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -1194,21 +1268,24 @@ class SubMConv4d(SparseConvolution): ...@@ -1194,21 +1268,24 @@ class SubMConv4d(SparseConvolution):
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None): name=None):
super(SubMConv4d, self).__init__(4, super(SubMConv4d,
in_channels, self).__init__(4,
out_channels, in_channels,
kernel_size, out_channels,
stride, kernel_size,
padding, stride,
dilation, padding,
groups, dilation,
bias, groups,
True, bias,
indice_key=indice_key, True,
algo=algo, indice_key=indice_key,
fp32_accum=fp32_accum, algo=algo,
name=name) fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
DEFAULT_SPARSE_CONV_TYPES = { DEFAULT_SPARSE_CONV_TYPES = {
...@@ -1229,4 +1306,3 @@ DEFAULT_SPARSE_CONV_TYPES = { ...@@ -1229,4 +1306,3 @@ DEFAULT_SPARSE_CONV_TYPES = {
SparseConvTranspose3d, SparseConvTranspose3d,
SparseConvTranspose4d, SparseConvTranspose4d,
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment