Commit f8c25027 authored by yan.yan's avatar yan.yan
Browse files

add act

parent 99c8a0bd
...@@ -35,6 +35,7 @@ from spconv.pytorch.modules import SparseModule ...@@ -35,6 +35,7 @@ from spconv.pytorch.modules import SparseModule
from spconv.constants import SAVED_WEIGHT_LAYOUT, ALL_WEIGHT_IS_KRSC, SPCONV_DEBUG_WEIGHT from spconv.constants import SAVED_WEIGHT_LAYOUT, ALL_WEIGHT_IS_KRSC, SPCONV_DEBUG_WEIGHT
from spconv.utils import nullcontext from spconv.utils import nullcontext
from torch.nn.init import calculate_gain from torch.nn.init import calculate_gain
from cumm import tensorview as tv
FILTER_HWIO = False FILTER_HWIO = False
...@@ -65,6 +66,9 @@ class SparseConvolution(SparseModule): ...@@ -65,6 +66,9 @@ class SparseConvolution(SparseModule):
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False, record_voxel_count: bool = False,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0,
name=None): name=None):
super(SparseConvolution, self).__init__(name=name) super(SparseConvolution, self).__init__(name=name)
assert groups == 1, "don't support groups for now" assert groups == 1, "don't support groups for now"
...@@ -131,6 +135,12 @@ class SparseConvolution(SparseModule): ...@@ -131,6 +135,12 @@ class SparseConvolution(SparseModule):
self.bias = Parameter(torch.Tensor(out_channels)) self.bias = Parameter(torch.Tensor(out_channels))
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.act_type = act_type
self.act_alpha = act_alpha
self.act_beta = act_beta
if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
self.reset_parameters() self.reset_parameters()
if hasattr(self, "_register_load_state_dict_pre_hook"): if hasattr(self, "_register_load_state_dict_pre_hook"):
self._register_load_state_dict_pre_hook( self._register_load_state_dict_pre_hook(
...@@ -141,7 +151,6 @@ class SparseConvolution(SparseModule): ...@@ -141,7 +151,6 @@ class SparseConvolution(SparseModule):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING) return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None return None
def _load_weight_different_layout(self, state_dict, prefix, local_metadata, def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, strict, missing_keys, unexpected_keys,
error_msgs): error_msgs):
...@@ -255,6 +264,12 @@ class SparseConvolution(SparseModule): ...@@ -255,6 +264,12 @@ class SparseConvolution(SparseModule):
indices = input.indices indices = input.indices
spatial_shape = input.spatial_shape spatial_shape = input.spatial_shape
batch_size = input.batch_size batch_size = input.batch_size
bias_for_training = self.bias if self.training else None
bias_for_infer = self.bias if not self.training else None
if self.training:
msg = "act don't support backward, only used in inference"
assert self.act_type == tv.gemm.Activation.None_, msg
if not self.subm: if not self.subm:
if self.transposed: if self.transposed:
out_spatial_shape = ops.get_deconv_output_size( out_spatial_shape = ops.get_deconv_output_size(
...@@ -393,19 +408,43 @@ class SparseConvolution(SparseModule): ...@@ -393,19 +408,43 @@ class SparseConvolution(SparseModule):
indice_pairs_calc = indice_pairs.to(features.device) indice_pairs_calc = indice_pairs.to(features.device)
if self.subm: if self.subm:
out_features = Fsp.indice_subm_conv( out_features = Fsp.indice_subm_conv(
features, self.weight, indice_pairs_calc, features,
indice_pair_num, outids.shape[0], algo, input._timer) self.weight,
indice_pairs_calc,
indice_pair_num,
outids.shape[0],
algo,
input._timer,
bias_for_infer,
self.act_alpha,
self.act_beta,
self.act_type)
else: else:
if self.inverse: if self.inverse:
out_features = Fsp.indice_inverse_conv( out_features = Fsp.indice_inverse_conv(
features, self.weight, indice_pairs_calc, features,
indice_pair_num, outids.shape[0], algo) self.weight,
indice_pairs_calc,
indice_pair_num,
outids.shape[0],
algo,
bias_for_infer,
self.act_alpha,
self.act_beta,
self.act_type)
else: else:
out_features = Fsp.indice_conv(features, self.weight, out_features = Fsp.indice_conv(
features,
self.weight,
indice_pairs_calc, indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], algo, outids.shape[0],
input._timer) algo,
input._timer,
bias_for_infer,
self.act_alpha,
self.act_beta,
self.act_type)
else: else:
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
...@@ -507,9 +546,14 @@ class SparseConvolution(SparseModule): ...@@ -507,9 +546,14 @@ class SparseConvolution(SparseModule):
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm, num_activate_out, masks, self.training, self.subm,
input._timer, self.fp32_accum) input._timer, self.fp32_accum,
if self.bias is not None: bias_for_infer,
out_features += self.bias self.act_alpha,
self.act_beta,
self.act_type)
if bias_for_training is not None:
out_features += bias_for_training
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
interval = time.time() - t interval = time.time() - t
...@@ -519,12 +563,9 @@ class SparseConvolution(SparseModule): ...@@ -519,12 +563,9 @@ class SparseConvolution(SparseModule):
out_tensor.benchmark_record[self.name]["num_out_points"].append( out_tensor.benchmark_record[self.name]["num_out_points"].append(
out_features.shape[0]) out_features.shape[0])
if not self.subm and not self.inverse and self.record_voxel_count: if not self.subm and not self.inverse and self.record_voxel_count:
if hasattr(self, if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
_MAX_NUM_VOXELS_DURING_TRAINING):
ops.maximum_value_int_( ops.maximum_value_int_(
getattr( getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING),
self,
_MAX_NUM_VOXELS_DURING_TRAINING),
outids.shape[0]) outids.shape[0])
out_tensor = out_tensor.replace_feature(out_features) out_tensor = out_tensor.replace_feature(out_features)
out_tensor.indices = outids out_tensor.indices = outids
......
...@@ -31,14 +31,17 @@ from spconv.pytorch.hash import HashTable ...@@ -31,14 +31,17 @@ from spconv.pytorch.hash import HashTable
from cumm.gemm.layout import to_stride from cumm.gemm.layout import to_stride
from typing import List from typing import List
from functools import reduce from functools import reduce
from cumm import tensorview as tv
_MAX_INT32 = 2147483647 _MAX_INT32 = 2147483647
_T = TypeVar("_T") _T = TypeVar("_T")
def identity_decorator(func: _T) -> _T: def identity_decorator(func: _T) -> _T:
return func return func
if PYTORCH_VERSION >= [1, 6, 0]: if PYTORCH_VERSION >= [1, 6, 0]:
import torch.cuda.amp as amp import torch.cuda.amp as amp
_TORCH_CUSTOM_FWD = amp.custom_fwd(cast_inputs=torch.float16) _TORCH_CUSTOM_FWD = amp.custom_fwd(cast_inputs=torch.float16)
...@@ -48,6 +51,7 @@ else: ...@@ -48,6 +51,7 @@ else:
_TORCH_CUSTOM_FWD = identity_decorator _TORCH_CUSTOM_FWD = identity_decorator
_TORCH_CUSTOM_BWD = identity_decorator _TORCH_CUSTOM_BWD = identity_decorator
class SparseConvFunction(Function): class SparseConvFunction(Function):
@staticmethod @staticmethod
@_TORCH_CUSTOM_FWD @_TORCH_CUSTOM_FWD
...@@ -58,7 +62,11 @@ class SparseConvFunction(Function): ...@@ -58,7 +62,11 @@ class SparseConvFunction(Function):
indice_pair_num, indice_pair_num,
num_activate_out, num_activate_out,
algo, algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False),
bias: Optional[torch.Tensor] = None,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer ctx.timer = timer
...@@ -70,7 +78,11 @@ class SparseConvFunction(Function): ...@@ -70,7 +78,11 @@ class SparseConvFunction(Function):
num_activate_out, num_activate_out,
False, False,
algo=algo, algo=algo,
timer=timer) timer=timer,
bias=bias,
act_alpha=act_alpha,
act_beta=act_beta,
act_type=act_type)
except Exception as e: except Exception as e:
msg = "[Exception|indice_conv]" msg = "[Exception|indice_conv]"
msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape}," msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
...@@ -102,7 +114,7 @@ class SparseConvFunction(Function): ...@@ -102,7 +114,7 @@ class SparseConvFunction(Function):
spconv_save_debug_data((indice_pairs, indice_pair_num)) spconv_save_debug_data((indice_pairs, indice_pair_num))
raise e raise e
return input_bp, filters_bp, None, None, None, None, None return input_bp, filters_bp, None, None, None, None, None, None, None, None, None
class SparseInverseConvFunction(Function): class SparseInverseConvFunction(Function):
...@@ -115,7 +127,11 @@ class SparseInverseConvFunction(Function): ...@@ -115,7 +127,11 @@ class SparseInverseConvFunction(Function):
indice_pair_num, indice_pair_num,
num_activate_out, num_activate_out,
algo, algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False),
bias: Optional[torch.Tensor] = None,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer ctx.timer = timer
...@@ -128,7 +144,11 @@ class SparseInverseConvFunction(Function): ...@@ -128,7 +144,11 @@ class SparseInverseConvFunction(Function):
True, True,
False, False,
algo=algo, algo=algo,
timer=timer) timer=timer,
bias=bias,
act_alpha=act_alpha,
act_beta=act_beta,
act_type=act_type)
except Exception as e: except Exception as e:
msg = "[Exception|indice_conv|inverse]" msg = "[Exception|indice_conv|inverse]"
msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape}," msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
...@@ -161,7 +181,7 @@ class SparseInverseConvFunction(Function): ...@@ -161,7 +181,7 @@ class SparseInverseConvFunction(Function):
spconv_save_debug_data((indice_pairs, indice_pair_num)) spconv_save_debug_data((indice_pairs, indice_pair_num))
raise e raise e
return input_bp, filters_bp, None, None, None, None, None return input_bp, filters_bp, None, None, None, None, None, None, None, None, None
class SparseImplicitGemmFunction(Function): class SparseImplicitGemmFunction(Function):
...@@ -181,23 +201,26 @@ class SparseImplicitGemmFunction(Function): ...@@ -181,23 +201,26 @@ class SparseImplicitGemmFunction(Function):
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None): fp32_accum: Optional[bool] = None,
bias: Optional[torch.Tensor] = None,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
try: try:
out, mask_out, mask_width = ops.implicit_gemm(features, filters, out, mask_out, mask_width = ops.implicit_gemm(
pair_fwd, features, filters, pair_fwd, pair_mask_fwd_splits,
pair_mask_fwd_splits, mask_argsort_fwd_splits, num_activate_out, masks, is_train,
mask_argsort_fwd_splits, is_subm, timer, fp32_accum, bias, act_alpha, act_beta,
num_activate_out, masks, act_type)
is_train, is_subm, timer,
fp32_accum)
except Exception as e: except Exception as e:
msg = "[Exception|implicit_gemm]" msg = "[Exception|implicit_gemm]"
msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape}," msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
msg += f"act={num_activate_out},issubm={is_subm},istrain={is_train}" msg += f"act={num_activate_out},issubm={is_subm},istrain={is_train}"
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
spconv_save_debug_data((pair_fwd, pair_bwd, pair_mask_fwd_splits, spconv_save_debug_data(
pair_mask_bwd_splits, mask_argsort_fwd_splits, mask_argsort_bwd_splits, (pair_fwd, pair_bwd, pair_mask_fwd_splits,
masks)) pair_mask_bwd_splits, mask_argsort_fwd_splits,
mask_argsort_bwd_splits, masks))
raise e raise e
ctx.save_for_backward(features, filters, pair_fwd, pair_bwd) ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
...@@ -253,12 +276,13 @@ class SparseImplicitGemmFunction(Function): ...@@ -253,12 +276,13 @@ class SparseImplicitGemmFunction(Function):
msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape}," msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
msg += f"issubm={is_subm},do={grad_output.shape}" msg += f"issubm={is_subm},do={grad_output.shape}"
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
spconv_save_debug_data((pair_fwd, pair_bwd, pair_mask_fwd_splits, spconv_save_debug_data(
pair_mask_bwd_splits, mask_argsort_fwd_splits, mask_argsort_bwd_splits, (pair_fwd, pair_bwd, pair_mask_fwd_splits,
masks)) pair_mask_bwd_splits, mask_argsort_fwd_splits,
mask_argsort_bwd_splits, masks))
raise e raise e
None_9 = [None] * 12 None_9 = [None] * 16
return (input_bp, filters_bp, *None_9) return (input_bp, filters_bp, *None_9)
...@@ -272,7 +296,11 @@ class SubMConvFunction(Function): ...@@ -272,7 +296,11 @@ class SubMConvFunction(Function):
indice_pair_num, indice_pair_num,
num_activate_out, num_activate_out,
algo, algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False),
bias: Optional[torch.Tensor] = None,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer ctx.timer = timer
...@@ -285,7 +313,11 @@ class SubMConvFunction(Function): ...@@ -285,7 +313,11 @@ class SubMConvFunction(Function):
False, False,
True, True,
algo=algo, algo=algo,
timer=timer) timer=timer,
bias=bias,
act_alpha=act_alpha,
act_beta=act_beta,
act_type=act_type)
except Exception as e: except Exception as e:
msg = "[Exception|indice_conv|subm]" msg = "[Exception|indice_conv|subm]"
msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape}," msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
...@@ -318,8 +350,7 @@ class SubMConvFunction(Function): ...@@ -318,8 +350,7 @@ class SubMConvFunction(Function):
spconv_save_debug_data((indice_pairs, indice_pair_num)) spconv_save_debug_data((indice_pairs, indice_pair_num))
raise e raise e
return input_bp, filters_bp, None, None, None, None, None, None, None, None, None
return input_bp, filters_bp, None, None, None, None, None
class SparseMaxPoolFunction(Function): class SparseMaxPoolFunction(Function):
...@@ -361,13 +392,17 @@ class SparseMaxPoolImplicitGemmFunction(Function): ...@@ -361,13 +392,17 @@ class SparseMaxPoolImplicitGemmFunction(Function):
features, out, grad_output, indice_pairs_bwd) features, out, grad_output, indice_pairs_bwd)
return input_bp, None, None, None return input_bp, None, None, None
class SparseAvgPoolImplicitGemmFunction(Function): class SparseAvgPoolImplicitGemmFunction(Function):
@staticmethod @staticmethod
@_TORCH_CUSTOM_FWD @_TORCH_CUSTOM_FWD
def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor, def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
indice_pairs_bwd: torch.Tensor, num_activate_out: int, calc_count): indice_pairs_bwd: torch.Tensor, num_activate_out: int,
out, count = ops.indice_avgpool_implicit_gemm(features, indice_pairs_fwd, calc_count):
num_activate_out, calc_count) out, count = ops.indice_avgpool_implicit_gemm(features,
indice_pairs_fwd,
num_activate_out,
calc_count)
ctx.save_for_backward(indice_pairs_bwd, features, out, count) ctx.save_for_backward(indice_pairs_bwd, features, out, count)
return out return out
...@@ -398,6 +433,7 @@ def _indice_to_scalar(indices: torch.Tensor, shape: List[int]): ...@@ -398,6 +433,7 @@ def _indice_to_scalar(indices: torch.Tensor, shape: List[int]):
scalar_inds += stride[i] * indices[:, i] scalar_inds += stride[i] * indices[:, i]
return scalar_inds.contiguous() return scalar_inds.contiguous()
def sparse_add_hash_based(*tens: SparseConvTensor): def sparse_add_hash_based(*tens: SparseConvTensor):
""" sparse add with misaligned indices. """ sparse add with misaligned indices.
if you use sparse add, the indice_dict will be dropped and impossible if you use sparse add, the indice_dict will be dropped and impossible
...@@ -438,14 +474,21 @@ def sparse_add_hash_based(*tens: SparseConvTensor): ...@@ -438,14 +474,21 @@ def sparse_add_hash_based(*tens: SparseConvTensor):
# assign arange to values of hash table # assign arange to values of hash table
count = table.assign_arange_() count = table.assign_arange_()
count_val = count.item() count_val = count.item()
out_features = torch.zeros([int(count_val), feat.shape[1]], dtype=feat.dtype, device=feat.device) out_features = torch.zeros([int(count_val), feat.shape[1]],
out_indices = torch.zeros([int(count_val), first.indices.shape[1]], dtype=first.indices.dtype, device=first.indices.device) dtype=feat.dtype,
device=feat.device)
out_indices = torch.zeros([int(count_val), first.indices.shape[1]],
dtype=first.indices.dtype,
device=first.indices.device)
for ten, scalar in zip(tens, scalars): for ten, scalar in zip(tens, scalars):
out_inds, _ = table.query(scalar) out_inds, _ = table.query(scalar)
out_inds = out_inds.long() out_inds = out_inds.long()
out_features[out_inds] += ten.features out_features[out_inds] += ten.features
out_indices[out_inds] = ten.indices out_indices[out_inds] = ten.indices
res = SparseConvTensor(out_features, out_indices, first.spatial_shape, first.batch_size, res = SparseConvTensor(out_features,
out_indices,
first.spatial_shape,
first.batch_size,
benchmark=first.benchmark) benchmark=first.benchmark)
if count_val == max_num_indices: if count_val == max_num_indices:
res.indice_dict = tens[max_num_indices_idx].indice_dict res.indice_dict = tens[max_num_indices_idx].indice_dict
...@@ -454,6 +497,7 @@ def sparse_add_hash_based(*tens: SparseConvTensor): ...@@ -454,6 +497,7 @@ def sparse_add_hash_based(*tens: SparseConvTensor):
res.thrust_allocator = first.thrust_allocator res.thrust_allocator = first.thrust_allocator
return res return res
def sparse_add(*tens: SparseConvTensor): def sparse_add(*tens: SparseConvTensor):
"""reuse torch.sparse. the internal is sort + unique """reuse torch.sparse. the internal is sort + unique
""" """
...@@ -461,7 +505,9 @@ def sparse_add(*tens: SparseConvTensor): ...@@ -461,7 +505,9 @@ def sparse_add(*tens: SparseConvTensor):
max_num_indices_idx = 0 max_num_indices_idx = 0
ten_ths: List[torch.Tensor] = [] ten_ths: List[torch.Tensor] = []
first = tens[0] first = tens[0]
res_shape = [first.batch_size, *first.spatial_shape, first.features.shape[1]] res_shape = [
first.batch_size, *first.spatial_shape, first.features.shape[1]
]
for i, ten in enumerate(tens): for i, ten in enumerate(tens):
assert ten.spatial_shape == tens[0].spatial_shape assert ten.spatial_shape == tens[0].spatial_shape
...@@ -470,14 +516,21 @@ def sparse_add(*tens: SparseConvTensor): ...@@ -470,14 +516,21 @@ def sparse_add(*tens: SparseConvTensor):
if max_num_indices < ten.features.shape[0]: if max_num_indices < ten.features.shape[0]:
max_num_indices_idx = i max_num_indices_idx = i
max_num_indices = ten.features.shape[0] max_num_indices = ten.features.shape[0]
ten_ths.append(torch.sparse_coo_tensor(ten.indices.T, ten.features, res_shape, requires_grad=True)) ten_ths.append(
torch.sparse_coo_tensor(ten.indices.T,
ten.features,
res_shape,
requires_grad=True))
c_th = reduce(lambda x, y: x + y, ten_ths).coalesce() c_th = reduce(lambda x, y: x + y, ten_ths).coalesce()
c_th_inds = c_th.indices().T.contiguous().int() c_th_inds = c_th.indices().T.contiguous().int()
c_th_values = c_th.values() c_th_values = c_th.values()
assert c_th_values.is_contiguous() assert c_th_values.is_contiguous()
res = SparseConvTensor(c_th_values, c_th_inds, first.spatial_shape, first.batch_size, res = SparseConvTensor(c_th_values,
c_th_inds,
first.spatial_shape,
first.batch_size,
benchmark=first.benchmark) benchmark=first.benchmark)
if c_th_values.shape[0] == max_num_indices: if c_th_values.shape[0] == max_num_indices:
res.indice_dict = tens[max_num_indices_idx].indice_dict res.indice_dict = tens[max_num_indices_idx].indice_dict
......
...@@ -29,6 +29,8 @@ from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator ...@@ -29,6 +29,8 @@ from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
import spconv.core_cc as _ext import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.core_cc.csrc.sparse.inference import InferenceOps
from spconv.utils import nullcontext from spconv.utils import nullcontext
if hasattr(_ext, "cumm"): if hasattr(_ext, "cumm"):
...@@ -784,7 +786,11 @@ def indice_conv(features: torch.Tensor, ...@@ -784,7 +786,11 @@ def indice_conv(features: torch.Tensor,
inverse: bool = False, inverse: bool = False,
subm: bool = False, subm: bool = False,
algo: ConvAlgo = ConvAlgo.Native, algo: ConvAlgo = ConvAlgo.Native,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False),
bias: Optional[torch.Tensor] = None,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
# filters: RSKC # filters: RSKC
# stream = get_current_stream() # stream = get_current_stream()
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
...@@ -793,6 +799,9 @@ def indice_conv(features: torch.Tensor, ...@@ -793,6 +799,9 @@ def indice_conv(features: torch.Tensor,
features = features.contiguous() features = features.contiguous()
if features.dtype == torch.int8 or features.dtype == torch.qint8: if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
bias_tv = tv.Tensor()
if bias is not None:
bias_tv = torch_tensor_to_tv(bias)
if SPCONV_CPP_GEMM and GEMM_CPP is not None: if SPCONV_CPP_GEMM and GEMM_CPP is not None:
# print("CPPPPPP!!!", features.device) # print("CPPPPPP!!!", features.device)
...@@ -822,10 +831,18 @@ def indice_conv(features: torch.Tensor, ...@@ -822,10 +831,18 @@ def indice_conv(features: torch.Tensor,
FILTER_HWIO, features_tv, filters_tv, FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv, arch, indice_pairs_tv, indice_pair_num_tv, arch,
num_activate_out, inverse, subm, algo.value, num_activate_out, inverse, subm, algo.value,
stream) stream, bias_tv, act_alpha, act_beta, act_type)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
return out_features return out_features
if not features.is_cuda:
stream = 0
else:
stream = get_current_stream()
has_bias = bias is not None
has_act = act_type != tv.gemm.Activation.None_
if has_bias or has_act:
assert features.is_cuda, "cpu don't support act and bias"
if not ALL_WEIGHT_IS_KRSC: if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0 kv_dim = 0
is_KC_not_CK = not FILTER_HWIO is_KC_not_CK = not FILTER_HWIO
...@@ -875,7 +892,17 @@ def indice_conv(features: torch.Tensor, ...@@ -875,7 +892,17 @@ def indice_conv(features: torch.Tensor,
out_features = torch.zeros((num_activate_out, out_channel), out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype, dtype=features.dtype,
device=features.device) device=features.device)
c = torch_tensor_to_tv(out_features)
if kv == 1 and subm: if kv == 1 and subm:
if (has_act and has_bias):
InferenceOps.bias_add_act_inplace(c, bias_tv, act_type, act_alpha, act_beta, stream)
else:
if has_act:
InferenceOps.activation_inplace(c, act_type, act_alpha, act_beta, stream)
if has_bias:
InferenceOps.bias_add_inplace(c, bias_tv, stream)
return out_features return out_features
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
...@@ -928,7 +955,6 @@ def indice_conv(features: torch.Tensor, ...@@ -928,7 +955,6 @@ def indice_conv(features: torch.Tensor,
SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices) SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices)
return out_features return out_features
stream = get_current_stream()
profile_idx = kv_center profile_idx = kv_center
if subm: if subm:
...@@ -1020,6 +1046,14 @@ def indice_conv(features: torch.Tensor, ...@@ -1020,6 +1046,14 @@ def indice_conv(features: torch.Tensor,
# gather_times += gather_time # gather_times += gather_time
inited = True inited = True
if (has_act and has_bias):
InferenceOps.bias_add_act_inplace(c, bias_tv, act_type, act_alpha, act_beta, stream)
else:
if has_act:
InferenceOps.activation_inplace(c, act_type, act_alpha, act_beta, stream)
if has_bias:
InferenceOps.bias_add_inplace(c, bias_tv, stream)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# print(out_features.mean(), out_features.max(), out_features.min()) # print(out_features.mean(), out_features.max(), out_features.min())
...@@ -1391,8 +1425,16 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1391,8 +1425,16 @@ def implicit_gemm(features: torch.Tensor,
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None): fp32_accum: Optional[bool] = None,
bias: Optional[torch.Tensor] = None,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
stream = get_current_stream() stream = get_current_stream()
bias_tv = tv.Tensor()
if bias is not None:
bias_tv = torch_tensor_to_tv(bias)
if SPCONV_CPP_GEMM and CONV_CPP is not None: if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device) alloc = TorchAllocator(features.device)
...@@ -1420,7 +1462,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1420,7 +1462,7 @@ def implicit_gemm(features: torch.Tensor,
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv, pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, num_activate_out, mask_tv, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum) timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None) mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train: if is_train:
...@@ -1512,6 +1554,10 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1512,6 +1554,10 @@ def implicit_gemm(features: torch.Tensor,
# t = time.time() # t = time.time()
# print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape) # print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"): # with tv.measure_and_print("f16 time"):
bias_tv = tv.Tensor()
if bias is not None:
bias_tv = torch_tensor_to_tv(bias)
with timer.record("implicit_gemm", stream): with timer.record("implicit_gemm", stream):
for j in range(num_split): for j in range(num_split):
beta = 0 if j == 0 else 1 beta = 0 if j == 0 else 1
...@@ -1530,7 +1576,11 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1530,7 +1576,11 @@ def implicit_gemm(features: torch.Tensor,
mask_width=-1, mask_width=-1,
beta=beta, beta=beta,
stream=stream, stream=stream,
verbose=False) verbose=False,
bias=bias_tv,
act_type=act_type,
act_alpha=act_alpha,
act_beta=act_beta)
# INT8_TEST = True # INT8_TEST = True
# if INT8_TEST: # if INT8_TEST:
# if features.shape[1] % 32 != 0: # if features.shape[1] % 32 != 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment