Commit cd99e7a6 authored by yan.yan's avatar yan.yan
Browse files

fix #575 use a flag to enable large-kernel algo

parent f101f97e
......@@ -2,7 +2,7 @@
## [2.3.5] - 2023-03-24
### Fixed
- pypi project reach size limit, so we need to assign a new version number.
- use a flag to enable large kernel algo (need time to compile at runtime)
## [2.3.4] - 2023-03-23
### Added
......
......@@ -26,4 +26,6 @@
* spconv 2.x in Windows 10 is 1.5x~2x slower than Linux. use Linux if possible.
* If you train with float32 and ampere or later GPUs, you can set ```spconv.constants.SPCONV_ALLOW_TF32``` to enable faster fp32 training.
See [benchmark](BENCHMARK.md) for more performance details of different algorithms.
* Different CUDA version of spconv may have different performance. Use newest cuda version if possible. For example, spconv-cu117 is faster than spconv-cu114, spconv-cu114 is faster than spconv-cu111.
\ No newline at end of file
* Different CUDA version of spconv may have different performance. Use newest cuda version if possible. For example, spconv-cu117 is faster than spconv-cu114, spconv-cu114 is faster than spconv-cu111.
* if your kernel size volume larger than 32, spconv will use a slower (and more inaccurate in fp16) algorithm. to use a faster algo for large kernel size (need time to compile at runtime), use ```large_kernel_fast_algo=True```
* use ```SparseGlobalMaxPool``` instead of use large kernel size when you need global pool.
\ No newline at end of file
......@@ -43,7 +43,9 @@ FILTER_HWIO = False
_MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training"
def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float, act_beta: float):
def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float,
act_beta: float):
if act_type == tv.gemm.Activation.None_:
return x
elif act_type == tv.gemm.Activation.ReLU:
......@@ -55,7 +57,9 @@ def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float,
else:
raise NotImplementedError
class SparseConvolutionBase:
def __init__(self,
ndim: int,
in_channels: int,
......@@ -76,7 +80,8 @@ class SparseConvolutionBase:
record_voxel_count: bool = False,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0):
act_beta: float = 0,
large_kernel_fast_algo: bool = False):
assert groups == 1, "don't support groups for now"
self.ndim = ndim
self.in_channels = in_channels
......@@ -103,7 +108,10 @@ class SparseConvolutionBase:
self.indice_key = indice_key
self.record_voxel_count = record_voxel_count
if algo is None:
if kv <= 128 and not CPU_ONLY_BUILD:
limit = 32
if large_kernel_fast_algo:
limit = 128
if kv <= limit and not CPU_ONLY_BUILD:
if kv < 8:
algo = ConvAlgo.MaskImplicitGemm
else:
......@@ -117,7 +125,7 @@ class SparseConvolutionBase:
self.algo = algo
self.fp32_accum = fp32_accum
# self.algo = ConvAlgo.Native
if self.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO:
# RSCK
......@@ -136,16 +144,23 @@ class SparseConvolutionBase:
self.zero_point = 0
if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
def is_inverseable(self):
return self.indice_key is not None and not self.subm
def _conv_forward(self, training: bool, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None,
channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None, name: Optional[str] = None,
sparse_unique_name: str = "",
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0):
def _conv_forward(self,
training: bool,
input: SparseConvTensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
add_input: Optional[SparseConvTensor] = None,
channel_scale: Optional[torch.Tensor] = None,
output_scale: Optional[float] = None,
name: Optional[str] = None,
sparse_unique_name: str = "",
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0):
# assert isinstance(input, SparseConvTensor)
is_int8 = input.is_quantized and weight.is_quantized
if is_int8:
......@@ -165,9 +180,9 @@ class SparseConvolutionBase:
if add_input is not None:
output_add_scale = add_input.q_scale()
if training:
msg = "act don't support backward, only used in inference"
msg = "act don't support backward, only used in inference"
assert self.act_type == tv.gemm.Activation.None_, msg
if not self.subm:
if self.transposed:
out_spatial_shape = ops.get_deconv_output_size(
......@@ -252,7 +267,7 @@ class SparseConvolutionBase:
indice_pair_num = datas.indice_pair_num
out_spatial_shape = datas.spatial_shape
self._check_inverse_reuse_valid(input, spatial_shape,
datas)
datas)
else:
if self.indice_key is not None and datas is not None:
outids = datas.out_indices
......@@ -283,8 +298,8 @@ class SparseConvolutionBase:
if input.benchmark:
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[
name]["indice_gen_time"].append(interval)
out_tensor.benchmark_record[name][
"indice_gen_time"].append(interval)
indice_data = IndiceData(outids,
indices,
......@@ -310,43 +325,21 @@ class SparseConvolutionBase:
indice_pairs_calc = indice_pairs.to(features.device)
if self.subm:
out_features = Fsp.indice_subm_conv(
features,
weight,
indice_pairs_calc,
indice_pair_num,
outids.shape[0],
algo,
input._timer,
bias_for_infer,
act_alpha,
act_beta,
act_type)
features, weight, indice_pairs_calc, indice_pair_num,
outids.shape[0], algo, input._timer, bias_for_infer,
act_alpha, act_beta, act_type)
else:
if self.inverse:
out_features = Fsp.indice_inverse_conv(
features,
weight,
indice_pairs_calc,
indice_pair_num,
outids.shape[0],
algo,
input._timer,
bias_for_infer,
act_alpha,
act_beta,
features, weight, indice_pairs_calc,
indice_pair_num, outids.shape[0], algo,
input._timer, bias_for_infer, act_alpha, act_beta,
act_type)
else:
out_features = Fsp.indice_conv(
features,
weight,
indice_pairs_calc,
indice_pair_num,
outids.shape[0],
algo,
input._timer,
bias_for_infer,
act_type,
act_beta,
features, weight, indice_pairs_calc,
indice_pair_num, outids.shape[0], algo,
input._timer, bias_for_infer, act_type, act_beta,
act_type)
else:
datas = input.find_indice_pair(self.indice_key)
......@@ -367,7 +360,7 @@ class SparseConvolutionBase:
# assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
self._check_inverse_reuse_valid(input, spatial_shape,
datas)
datas)
else:
if self.indice_key is not None and datas is not None:
outids = datas.out_indices
......@@ -416,8 +409,8 @@ class SparseConvolutionBase:
if input.benchmark:
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[
name]["indice_gen_time"].append(interval)
out_tensor.benchmark_record[name][
"indice_gen_time"].append(interval)
outids = res[0]
num_inds_per_loc = res[1]
pair_fwd = res[2]
......@@ -465,28 +458,34 @@ class SparseConvolutionBase:
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, training, self.subm,
input._timer, self.fp32_accum,
bias_cur,
act_alpha,
act_beta,
act_type)
input._timer, self.fp32_accum, bias_cur, act_alpha,
act_beta, act_type)
else:
output_dtype = None
output_dtype = None
if output_scale is None:
output_dtype = weight.dtype
out_features, _, _ = ops.implicit_gemm(
features, weight_cur, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out, masks, training, self.subm,
input._timer, self.fp32_accum,
features,
weight_cur,
pair_fwd,
pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out,
masks,
training,
self.subm,
input._timer,
self.fp32_accum,
bias_cur,
act_alpha,
act_beta,
act_type,
# TODO do we really need output scale to scale bias in kernel?
1.0 if output_scale is None else output_scale, # output_scale
channel_scale, # scale
output_add=add_input.features if add_input is not None else None,
1.0 if output_scale is None else
output_scale, # output_scale
channel_scale, # scale
output_add=add_input.features
if add_input is not None else None,
output_add_scale=output_add_scale,
output_dtype=output_dtype)
......@@ -511,8 +510,10 @@ class SparseConvolutionBase:
out_tensor.spatial_shape = out_spatial_shape
if add_input is not None and not is_int8:
# in int8, we apply add + act in kernel.
out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
out_tensor = out_tensor.replace_feature(
_apply_act(out_tensor.features + add_input.features,
self.act_type, self.act_alpha, self.act_beta))
return out_tensor
def _check_subm_reuse_valid(self, inp: SparseConvTensor,
......@@ -539,9 +540,9 @@ class SparseConvolutionBase:
)
def _check_inverse_reuse_valid(self, inp: SparseConvTensor,
spatial_shape: List[int],
datas: Union[ImplicitGemmIndiceData,
IndiceData]):
spatial_shape: List[int],
datas: Union[ImplicitGemmIndiceData,
IndiceData]):
if self.kernel_size != datas.ksize:
raise ValueError(
f"Inverse with same indice_key must have same kernel"
......@@ -556,8 +557,8 @@ class SparseConvolutionBase:
raise ValueError(
f"Inverse with same indice_key must have same num of indices"
f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}, "
"please check Inverse Convolution in ."
)
"please check Inverse Convolution in .")
class SparseConvolution(SparseConvolutionBase, SparseModule):
__constants__ = [
......@@ -586,11 +587,21 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0,
large_kernel_fast_algo: bool = False,
name=None,
device=None,
device=None,
dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
SparseConvolutionBase.__init__(self, ndim, in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
SparseConvolutionBase.__init__(
self,
ndim,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias=False,
subm=subm,
output_padding=output_padding,
......@@ -602,14 +613,17 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
record_voxel_count=record_voxel_count,
act_type=act_type,
act_alpha=act_alpha,
act_beta=act_beta)
act_beta=act_beta,
large_kernel_fast_algo=large_kernel_fast_algo)
SparseModule.__init__(self, name=name)
if record_voxel_count and not self.subm and not self.inverse:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self.register_buffer(_MAX_NUM_VOXELS_DURING_TRAINING,
torch.zeros(1, dtype=torch.int32, device=device))
self.weight = Parameter(torch.zeros(*self.weight_shape, **factory_kwargs))
self.register_buffer(
_MAX_NUM_VOXELS_DURING_TRAINING,
torch.zeros(1, dtype=torch.int32, device=device))
self.weight = Parameter(
torch.zeros(*self.weight_shape, **factory_kwargs))
if bias:
self.bias = Parameter(torch.zeros(out_channels, **factory_kwargs))
else:
......@@ -636,8 +650,7 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
error_msgs):
name = prefix + _MAX_NUM_VOXELS_DURING_TRAINING
if self.record_voxel_count and not self.subm and not self.inverse and name not in state_dict:
state_dict[name] = torch.zeros(
1, dtype=torch.int32)
state_dict[name] = torch.zeros(1, dtype=torch.int32)
if not SAVED_WEIGHT_LAYOUT:
return
key = prefix + "weight"
......@@ -736,14 +749,23 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
return self._conv_forward(self.training, input, self.weight, self.bias, add_input,
name=self.name, sparse_unique_name=self._sparse_unique_name, act_type=self.act_type,
act_alpha=self.act_alpha, act_beta=self.act_beta)
def forward(self,
input: SparseConvTensor,
add_input: Optional[SparseConvTensor] = None):
return self._conv_forward(self.training,
input,
self.weight,
self.bias,
add_input,
name=self.name,
sparse_unique_name=self._sparse_unique_name,
act_type=self.act_type,
act_alpha=self.act_alpha,
act_beta=self.act_beta)
class SparseConv1d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -757,6 +779,7 @@ class SparseConv1d(SparseConvolution):
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseConv1d,
self).__init__(1,
......@@ -772,10 +795,12 @@ class SparseConv1d(SparseConvolution):
algo=algo,
fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SparseConv2d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -789,6 +814,7 @@ class SparseConv2d(SparseConvolution):
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseConv2d,
self).__init__(2,
......@@ -803,11 +829,13 @@ class SparseConv2d(SparseConvolution):
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count,
name=name)
class SparseConv3d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -821,6 +849,7 @@ class SparseConv3d(SparseConvolution):
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseConv3d,
self).__init__(3,
......@@ -835,11 +864,13 @@ class SparseConv3d(SparseConvolution):
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count,
name=name)
class SparseConv4d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -853,6 +884,7 @@ class SparseConv4d(SparseConvolution):
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseConv4d,
self).__init__(4,
......@@ -867,11 +899,13 @@ class SparseConv4d(SparseConvolution):
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count,
name=name)
class SparseConvTranspose1d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -885,6 +919,7 @@ class SparseConvTranspose1d(SparseConvolution):
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseConvTranspose1d,
self).__init__(1,
......@@ -900,11 +935,13 @@ class SparseConvTranspose1d(SparseConvolution):
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count,
name=name)
class SparseConvTranspose2d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -918,6 +955,7 @@ class SparseConvTranspose2d(SparseConvolution):
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseConvTranspose2d,
self).__init__(2,
......@@ -933,11 +971,13 @@ class SparseConvTranspose2d(SparseConvolution):
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count,
name=name)
class SparseConvTranspose3d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -951,6 +991,7 @@ class SparseConvTranspose3d(SparseConvolution):
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseConvTranspose3d,
self).__init__(3,
......@@ -966,11 +1007,13 @@ class SparseConvTranspose3d(SparseConvolution):
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count,
name=name)
class SparseConvTranspose4d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -984,6 +1027,7 @@ class SparseConvTranspose4d(SparseConvolution):
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseConvTranspose4d,
self).__init__(4,
......@@ -999,11 +1043,13 @@ class SparseConvTranspose4d(SparseConvolution):
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
record_voxel_count=record_voxel_count,
name=name)
class SparseInverseConv1d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -1012,20 +1058,24 @@ class SparseInverseConv1d(SparseConvolution):
bias=True,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseInverseConv1d, self).__init__(1,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
super(SparseInverseConv1d,
self).__init__(1,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SparseInverseConv2d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -1034,20 +1084,24 @@ class SparseInverseConv2d(SparseConvolution):
bias=True,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseInverseConv2d, self).__init__(2,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
super(SparseInverseConv2d,
self).__init__(2,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SparseInverseConv3d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -1056,20 +1110,24 @@ class SparseInverseConv3d(SparseConvolution):
bias=True,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseInverseConv3d, self).__init__(3,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
super(SparseInverseConv3d,
self).__init__(3,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SparseInverseConv4d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -1078,20 +1136,24 @@ class SparseInverseConv4d(SparseConvolution):
bias=True,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None):
super(SparseInverseConv4d, self).__init__(4,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
super(SparseInverseConv4d,
self).__init__(4,
in_channels,
out_channels,
kernel_size,
bias=bias,
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SubMConv1d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -1104,24 +1166,28 @@ class SubMConv1d(SparseConvolution):
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None):
super(SubMConv1d, self).__init__(1,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
super(SubMConv1d,
self).__init__(1,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SubMConv2d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -1134,24 +1200,28 @@ class SubMConv2d(SparseConvolution):
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None):
super(SubMConv2d, self).__init__(2,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
super(SubMConv2d,
self).__init__(2,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SubMConv3d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -1164,24 +1234,28 @@ class SubMConv3d(SparseConvolution):
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None):
super(SubMConv3d, self).__init__(3,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
super(SubMConv3d,
self).__init__(3,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
class SubMConv4d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
......@@ -1194,21 +1268,24 @@ class SubMConv4d(SparseConvolution):
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
large_kernel_fast_algo: bool = False,
name=None):
super(SubMConv4d, self).__init__(4,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
super(SubMConv4d,
self).__init__(4,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
large_kernel_fast_algo=large_kernel_fast_algo,
name=name)
DEFAULT_SPARSE_CONV_TYPES = {
......@@ -1229,4 +1306,3 @@ DEFAULT_SPARSE_CONV_TYPES = {
SparseConvTranspose3d,
SparseConvTranspose4d,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment