Commit 21bb00ae authored by Yan Yan
Browse files

still working on c++ only

parent 899008fa
......@@ -75,26 +75,31 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
output_size.append(size)
return output_size
class _HashData:
    """Key/value scratch buffers handed to the indice-pair generation ops.

    Both layouts hold ``num * 2`` slots for keys and ``num * 2`` slots for
    values, and both expose the same two attributes for the native ops:
    ``hashdata_k_tv`` and ``hashdata_v_tv``.

    Args:
        num: number of entries to size the buffers for (callers in this file
            pass ``out_inds.shape[0]``).
        use_i64: when True, keys are stored in a separate int64 tensor;
            otherwise keys and values share a single int32 tensor.
        device: torch device the buffers are allocated on.
    """

    def __init__(self, num: int, use_i64: bool, device: torch.device) -> None:
        slots = num * 2
        if use_i64:
            # Keys need 64 bits, so keys and values cannot share one tensor.
            self.hashdata_k = torch.empty((slots, ),
                                          dtype=torch.int64,
                                          device=device)
            self.hashdata_v = torch.empty((slots, ),
                                          dtype=torch.int32,
                                          device=device)
            self.hashdata_k_tv = torch_tensor_to_tv(self.hashdata_k)
            self.hashdata_v_tv = torch_tensor_to_tv(self.hashdata_v)
        else:
            # One int32 allocation: row 0 holds the keys, row 1 the values.
            self.hashdata = torch.empty((2, slots),
                                        dtype=torch.int32,
                                        device=device)
            hashdata_tv = torch_tensor_to_tv(self.hashdata)
            self.hashdata_k_tv = hashdata_tv[0]
            self.hashdata_v_tv = hashdata_tv[1]
def get_indice_pairs(indices: torch.Tensor,
batch_size: int,
spatial_shape: List[int],
......@@ -119,13 +124,18 @@ def get_indice_pairs(indices: torch.Tensor,
if indices.is_cuda:
stream = get_current_stream()
num_act_out = SpconvOps.get_indice_pairs(alloc, torch_tensor_to_tv(indices), batch_size, spatial_shape,
algo.value, ksize, stride, padding, dilation, out_padding, subm, transpose, stream)
num_act_out = SpconvOps.get_indice_pairs(alloc,
torch_tensor_to_tv(indices),
batch_size, spatial_shape,
algo.value, ksize, stride,
padding, dilation,
out_padding, subm, transpose,
stream)
if subm:
out_inds = indices
else:
out_inds = alloc.allocated[AllocKeys.OutIndices]
pair = alloc.allocated[AllocKeys.Pair]
pair = alloc.allocated[AllocKeys.PairFwd]
indice_num_per_loc = alloc.allocated[AllocKeys.IndiceNumPerLoc]
# print(subm, out_inds.shape, pair.shape, indice_num_per_loc.shape, num_act_out)
return out_inds[:num_act_out], pair, indice_num_per_loc
......@@ -146,7 +156,7 @@ def get_indice_pairs(indices: torch.Tensor,
)
assert algo == ConvAlgo.Native, "TODO"
# indices = indices.cpu()
spatial_volume = functools.reduce(lambda x, y: x * y, spatial_shape, 1)
spatial_volume = functools.reduce(lambda x, y: x * y, out_shape, 1)
use_int64_hash_k = spatial_volume >= INT32_MAX or DEBUG_INT64_HASH_K
indice_dtype = torch.int64 if use_int64_hash_k else indices.dtype
pair = torch.full((2, kv, indices.shape[0]),
......@@ -164,7 +174,8 @@ def get_indice_pairs(indices: torch.Tensor,
out_inds = indices
if indices.is_cuda:
stream = get_current_stream()
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k, indices.device)
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device)
# hashdata = torch.empty((out_inds.shape[0] * 2, ),
# dtype=torch.int64,
# device=indices.device)
......@@ -234,7 +245,8 @@ def get_indice_pairs(indices: torch.Tensor,
# hashdata = torch.empty((out_inds.shape[0] * 2, ),
# dtype=torch.int64,
# device=indices.device)
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k, indices.device)
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
......@@ -281,6 +293,7 @@ def get_indice_pairs(indices: torch.Tensor,
# print("REGU", time.time() - t)
return out_inds, pair, indice_num_per_loc
def get_indice_pairs_implicit_gemm(
indices: torch.Tensor,
batch_size: int,
......@@ -303,11 +316,11 @@ def get_indice_pairs_implicit_gemm(
out_inds,
num_inds_per_loc,
pair_fwd,
pair_bwd, # None if subm or inference mode
pair_bwd, # torch.Tensor() if subm or inference mode
pair_mask_fwd_splits,
pair_mask_bwd_splits, # None if subm or inference mode
pair_mask_bwd_splits, # torch.Tensor() if subm or inference mode
mask_argsort_fwd_splits,
mask_argsort_bwd_splits, # None if subm or inference mode
mask_argsort_bwd_splits, # torch.Tensor() if subm or inference mode
masks,
)
"""
......@@ -316,39 +329,47 @@ def get_indice_pairs_implicit_gemm(
thalloc = TorchAllocator(indices.device)
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc, torch_tensor_to_tv(indices), batch_size, spatial_shape,
algo.value, ksize, stride, padding, dilation, out_padding, subm, transpose, is_train, stream,
num_out_act_bound)
algo.value, ksize, stride, padding, dilation, out_padding, subm,
transpose, is_train, stream, num_out_act_bound)
mask_split_count = mask_tensor.dim(0)
masks = [mask_tensor[i:i+1].numpy() for i in range(mask_split_count)]
masks = [mask_tensor[i:i + 1].numpy() for i in range(mask_split_count)]
if subm:
out_inds = indices
else:
out_inds = thalloc.allocated[AllocKeys.OutIndices]
pair = thalloc.allocated[AllocKeys.Pair]
indice_num_per_loc = thalloc.allocated[AllocKeys.IndiceNumPerLoc]
if subm:
# for subm, if training, pair shape is [2, kv, ...]
# if not training, pair is [1, kv, ...]
pair = thalloc.allocated[AllocKeys.PairFwd]
pair_mask = thalloc.allocated[AllocKeys.PairMask]
mask_argsort = thalloc.allocated[AllocKeys.MaskArgSort]
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
pair_mask_in_splits = [
pair_mask[i] for i in range(mask_split_count)
]
mask_argsort_in_splits = [
mask_argsort[i] for i in range(mask_split_count)
]
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_bwd = torch.Tensor()
pair_fwd = pair[0]
if is_train:
assert pair.shape[0] == 2
pair_bwd = pair[1]
return (out_inds, indice_num_per_loc, pair[0], pair_bwd,
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
pair_bwd = pair
pair_bwd = thalloc.allocated.get(AllocKeys.PairBwd, torch.Tensor())
pair_fwd = thalloc.allocated[AllocKeys.PairFwd]
pair_mask_fwd = thalloc.allocated[AllocKeys.PairMask]
pair_mask_bwd = torch.Tensor()
mask_argsort_bwd = torch.Tensor()
if is_train:
pair_mask_bwd = thalloc.allocated[AllocKeys.PairMaskBwd]
mask_argsort_bwd = thalloc.allocated[AllocKeys.MaskArgSortBwd]
mask_argsort_fwd = thalloc.allocated[AllocKeys.MaskArgSort]
if not is_train:
pair_bwd = torch.Tensor()
pair_mask_bwd_splits: List[torch.Tensor] = []
mask_argsort_bwd_splits: List[torch.Tensor] = []
else:
......@@ -377,9 +398,6 @@ def get_indice_pairs_implicit_gemm(
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
# TODO in future we will support up to 128 kernel volume.
assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm"
spatial_volume = functools.reduce(lambda x, y: x * y, spatial_shape, 1)
use_int64_hash_k = spatial_volume >= INT32_MAX or DEBUG_INT64_HASH_K
indice_dtype = torch.int64 if use_int64_hash_k else indices.dtype
if not subm:
if transpose:
......@@ -394,6 +412,9 @@ def get_indice_pairs_implicit_gemm(
raise ValueError(
f"your out spatial shape {out_shape} reach zero!!! input shape: {spatial_shape}"
)
spatial_volume = functools.reduce(lambda x, y: x * y, spatial_shape, 1)
use_int64_hash_k = spatial_volume >= INT32_MAX or DEBUG_INT64_HASH_K
indice_dtype = torch.int64 if use_int64_hash_k else indices.dtype
assert algo == ConvAlgo.MaskImplicitGemm or algo == ConvAlgo.MaskSplitImplicitGemm, "TODO"
is_mask_split = algo == ConvAlgo.MaskSplitImplicitGemm
mask_split_count = 2 if is_mask_split else 1
......@@ -433,7 +454,8 @@ def get_indice_pairs_implicit_gemm(
# hashdata = torch.empty((out_inds.shape[0] * 2, ),
# dtype=torch.int64,
# device=indices.device)
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k, indices.device)
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device)
pair_mask = torch.empty((mask_split_count, indices.shape[0]),
dtype=torch.int32,
......@@ -552,7 +574,8 @@ def get_indice_pairs_implicit_gemm(
device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
dtype=tv.uint32)
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k, indices.device)
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device)
# hashdata = torch.empty((out_inds.shape[0] * 2, ),
# dtype=torch.int64,
......@@ -714,13 +737,14 @@ def indice_conv(features: torch.Tensor,
if SPCONV_CPP_GEMM and GEMM_CPP is not None:
# print("CPPPPPP!!!", features.device)
alloc = TorchAllocator(features.device)
from spconv.core_cc.csrc.sparse.convops import SimpleExternalSpconvMatmul
# ext_mm = TorchSpconvMatmul(alloc)
if features.is_cuda:
ext_mm = SimpleExternalSpconvMatmul(alloc)
else:
ext_mm = TorchSpconvMatmul(alloc)
# from spconv.core_cc.csrc.sparse.convops import SimpleExternalSpconvMatmul
# if features.is_cuda:
# ext_mm = SimpleExternalSpconvMatmul(alloc)
# else:
# ext_mm = TorchSpconvMatmul(alloc)
alloc.allocated[AllocKeys.Features] = features
alloc.allocated[AllocKeys.Filters] = filters
......@@ -731,13 +755,14 @@ def indice_conv(features: torch.Tensor,
stream = 0
if features.is_cuda:
stream = get_current_stream()
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, indice_pairs_tv, indice_pair_num_tv, num_activate_out,
inverse, subm, algo.value, stream)
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC,
FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv,
num_activate_out, inverse, subm, algo.value,
stream)
out_features = alloc.allocated[AllocKeys.OutFeatures]
return out_features
if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
......@@ -779,7 +804,9 @@ def indice_conv(features: torch.Tensor,
features_np = torch_tensor_to_tv(features).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
out_features_np = torch_tensor_to_tv(out_features).numpy_view()
np.matmul(features_np, filters_np[:, kv_center].T, out=out_features_np)
np.matmul(features_np,
filters_np[:, kv_center].T,
out=out_features_np)
# out_features = torch.mm(features, filters[:, kv_center].T)
else:
out_features = torch.zeros((num_activate_out, out_channel),
......@@ -826,10 +853,13 @@ def indice_conv(features: torch.Tensor,
if features.dtype == torch.float16:
inp_buffer_np = torch_tensor_to_tv(inp_buffer).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
filters_i_np = filters_np[i] if not ALL_WEIGHT_IS_KRSC else filters_np[:, i]
filters_i_np = filters_np[
i] if not ALL_WEIGHT_IS_KRSC else filters_np[:, i]
filters_cur_np = filters_i_np if not is_KC_not_CK else filters_i_np.T
out_buffer_np = torch_tensor_to_tv(out_buffer).numpy_view()
np.matmul(inp_buffer_np[:nhot], filters_cur_np, out=out_buffer_np[:nhot])
np.matmul(inp_buffer_np[:nhot],
filters_cur_np,
out=out_buffer_np[:nhot])
else:
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices)
......@@ -968,8 +998,10 @@ def indice_conv_backward(features: torch.Tensor,
stream = 0
if features.is_cuda:
stream = get_current_stream()
ConvGemmOps.indice_conv_backward(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv, indice_pairs_tv, indice_pair_num_tv,
ConvGemmOps.indice_conv_backward(alloc, ext_mm, GEMM_CPP,
ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv,
indice_pairs_tv, indice_pair_num_tv,
inverse, subm, algo.value, stream)
din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters]
......@@ -1076,10 +1108,14 @@ def indice_conv_backward(features: torch.Tensor,
filters_KC = filters_i if is_KC_not_CK else filters_i.T
if is_KC_not_CK:
# KN @ NC
torch.mm(out_buffer[:nhot].T, inp_buffer[:nhot], out=dfilters_i)
torch.mm(out_buffer[:nhot].T,
inp_buffer[:nhot],
out=dfilters_i)
else:
# CN @ NK
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_i)
torch.mm(inp_buffer[:nhot].T,
out_buffer[:nhot],
out=dfilters_i)
# NK @ KC
torch.mm(out_buffer[:nhot], filters_KC, out=inp_buffer[:nhot])
SpconvOps.scatter_add_cpu(din_tv, inp_buffer_tv, inp_indices)
......@@ -1295,8 +1331,12 @@ def implicit_gemm(features: torch.Tensor,
alloc = TorchAllocator(features.device)
features_tv = torch_tensor_to_tv(features)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_splits_tv = [torch_tensor_to_tv(t, tv.uint32) for t in pair_mask_fwd_splits]
mask_argsort_fwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_fwd_splits]
pair_mask_fwd_splits_tv = [
torch_tensor_to_tv(t, tv.uint32) for t in pair_mask_fwd_splits
]
mask_argsort_fwd_splits_tv = [
torch_tensor_to_tv(t) for t in mask_argsort_fwd_splits
]
filters_tv = torch_tensor_to_tv(filters)
mask = np.concatenate(masks)
......@@ -1307,9 +1347,11 @@ def implicit_gemm(features: torch.Tensor,
auto_fp32_accum = fp32_accum is None
if fp32_accum is None:
fp32_accum = False
mask_width = ConvGemmOps.implicit_gemm(alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, pair_mask_fwd_splits_tv,
mask_argsort_fwd_splits_tv, num_activate_out, mask_tv, is_train, is_subm, stream, timer_cpp, auto_fp32_accum,
fp32_accum)
mask_width = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, is_train, is_subm, stream, timer_cpp,
auto_fp32_accum, fp32_accum)
out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train:
......@@ -1543,12 +1585,19 @@ def implicit_gemm_backward(features: torch.Tensor,
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_bwd_tv = torch_tensor_to_tv(pair_bwd)
pair_mask_fwd_splits_tv = [torch_tensor_to_tv(t) for t in pair_mask_fwd_splits]
pair_mask_bwd_splits_tv = [torch_tensor_to_tv(t) for t in pair_mask_bwd_splits]
mask_argsort_fwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_fwd_splits]
mask_argsort_bwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_bwd_splits]
pair_mask_fwd_splits_tv = [
torch_tensor_to_tv(t) for t in pair_mask_fwd_splits
]
pair_mask_bwd_splits_tv = [
torch_tensor_to_tv(t) for t in pair_mask_bwd_splits
]
mask_argsort_fwd_splits_tv = [
torch_tensor_to_tv(t) for t in mask_argsort_fwd_splits
]
mask_argsort_bwd_splits_tv = [
torch_tensor_to_tv(t) for t in mask_argsort_bwd_splits
]
filters_tv = torch_tensor_to_tv(filters)
out_bp_tv = torch_tensor_to_tv(out_bp)
......@@ -1564,10 +1613,12 @@ def implicit_gemm_backward(features: torch.Tensor,
auto_fp32_accum = fp32_accum is None
if fp32_accum is None:
fp32_accum = False
ConvGemmOps.implicit_gemm_backward(alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv, mask_argsort_fwd_splits_tv,
mask_argsort_bwd_splits_tv, mask_output_fwd_tv, mask_tv, mask_width, is_subm, stream, timer_cpp, auto_fp32_accum,
fp32_accum)
ConvGemmOps.implicit_gemm_backward(
alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum)
din = alloc.allocated[AllocKeys.DIn]
dfilters = alloc.allocated[AllocKeys.DFilters]
return din, dfilters
......@@ -1849,3 +1900,65 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
indice_pairs_tv, stream)
return din
def indice_avgpool_implicit_gemm(features: torch.Tensor,
                                 indice_pairs: torch.Tensor, num_activate_out,
                                 calc_count: bool):
    """Run the forward average-pool op over implicit-gemm indice pairs.

    Args:
        features: input features; must live on a CUDA device. A contiguous
            copy is made if the tensor is not already contiguous.
        indice_pairs: indice-pair tensor forwarded unchanged to the native op.
        num_activate_out: number of output rows to allocate.
        calc_count: when True, an int32 per-output counter of shape
            ``(num_activate_out,)`` is allocated (zero-filled) and passed to
            the native op.

    Returns:
        ``(out_features, count_out)``. ``out_features`` has shape
        ``(num_activate_out, features.shape[-1])`` with the dtype/device of
        ``features``. ``count_out`` is an empty ``torch.Tensor()`` when
        ``calc_count`` is False.

    Raises:
        AssertionError: if ``features`` is not a CUDA tensor.
    """
    # Check the CUDA-only precondition before touching the stream or
    # allocating anything.
    assert features.is_cuda, "features must be a CUDA tensor"
    if not features.is_contiguous():
        features = features.contiguous()
    out_channel = features.shape[-1]
    # NOTE(review): allocated uninitialised; presumably the native op writes
    # every output row -- confirm against the kernel.
    out_features = torch.empty((num_activate_out, out_channel),
                               dtype=features.dtype,
                               device=features.device)
    # Empty placeholders tell the native op that no count is requested.
    count_out = torch.Tensor()
    count_out_tv = tv.Tensor()
    if calc_count:
        count_out = torch.zeros((num_activate_out, ),
                                dtype=torch.int32,
                                device=features.device)
        count_out_tv = torch_tensor_to_tv(count_out)
    out_features_tv = torch_tensor_to_tv(out_features)
    features_tv = torch_tensor_to_tv(features)
    indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
    stream = get_current_stream()
    SpconvOps.avgpool_implicit_gemm_forward(out_features_tv, features_tv,
                                            indice_pairs_tv, count_out_tv,
                                            stream)
    return out_features, count_out
def indice_avgpool_implicit_gemm_backward(out_bp, indice_pairs, count_out):
    """Run the backward average-pool op over implicit-gemm indice pairs.

    Args:
        out_bp: gradient w.r.t. the pooled output; must live on a CUDA
            device. A contiguous copy is made if needed.
        indice_pairs: indice-pair tensor forwarded to the native op; its
            second dimension sets the number of rows of the returned gradient.
        count_out: count tensor forwarded to the native op (the second value
            returned by the forward function in this file).

    Returns:
        ``din``: zero-initialised tensor of shape
        ``(indice_pairs.shape[1], out_bp.shape[1])`` with the dtype/device of
        ``out_bp``, filled in by the native op.

    Raises:
        AssertionError: if ``out_bp`` is not a CUDA tensor.
    """
    # Check the CUDA-only precondition before allocating anything.
    assert out_bp.is_cuda, "out_bp must be a CUDA tensor"
    if not out_bp.is_contiguous():
        out_bp = out_bp.contiguous()
    din = torch.zeros((indice_pairs.shape[1], out_bp.shape[1]),
                      dtype=out_bp.dtype,
                      device=out_bp.device)
    stream = get_current_stream()
    count_out_tv = torch_tensor_to_tv(count_out)
    out_bp_tv = torch_tensor_to_tv(out_bp)
    din_tv = torch_tensor_to_tv(din)
    indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
    SpconvOps.avgpool_implicit_gemm_backward(out_bp_tv, din_tv,
                                             indice_pairs_tv, count_out_tv,
                                             stream)
    return din
def maximum_value_int_(ten: torch.Tensor, value: int):
    """Forward ``ten`` and ``value`` to ``SpconvOps.maximum_value_int``.

    On CPU-only builds the stream handle is 0 and ``ten`` must not be a CUDA
    tensor; otherwise the current stream is passed along.

    NOTE(review): presumably updates ``ten`` in place so that it holds at
    least ``value`` -- confirm against the native op.
    """
    if CPU_ONLY_BUILD:
        assert not ten.is_cuda
        stream = 0
    else:
        stream = get_current_stream()
    SpconvOps.maximum_value_int(torch_tensor_to_tv(ten), value, stream)
......@@ -30,6 +30,7 @@ from spconv.pytorch.core import IndiceData, ImplicitGemmIndiceData, expand_nd
from spconv.pytorch.modules import SparseModule
from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.utils import nullcontext
from .conv import _MAX_NUM_VOXELS_DURING_TRAINING
class SparseMaxPool(SparseModule):
......@@ -42,6 +43,7 @@ class SparseMaxPool(SparseModule):
indice_key: Optional[str] = None,
subm: bool = False,
algo: Optional[ConvAlgo] = None,
record_voxel_count: bool = False,
name=None):
super(SparseMaxPool, self).__init__(name=name)
self.ndim = ndim
......@@ -52,6 +54,12 @@ class SparseMaxPool(SparseModule):
self.stride = expand_nd(ndim, stride)
self.padding = expand_nd(ndim, padding)
self.subm = subm
if record_voxel_count and not self.subm:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self.register_buffer(_MAX_NUM_VOXELS_DURING_TRAINING,
torch.zeros(1, dtype=torch.int32))
self.record_voxel_count = record_voxel_count
self.dilation = expand_nd(ndim, dilation)
self.indice_key = indice_key
kv = int(np.prod(kernel_size))
......@@ -220,6 +228,136 @@ class SparseMaxPool(SparseModule):
features.shape[0])
out_tensor.benchmark_record[self.name]["num_out_points"].append(
out_features.shape[0])
if not self.subm and self.record_voxel_count:
if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
ops.maximum_value_int_(
getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING),
outids.shape[0])
out_tensor = out_tensor.replace_feature(out_features)
out_tensor.indices = outids
out_tensor.indice_dict = indice_dict
out_tensor.spatial_shape = out_spatial_shape
return out_tensor
class SparseAvgPool(SparseModule):
def __init__(self,
             ndim,
             kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
             stride: Optional[Union[int, List[int], Tuple[int, ...]]] = 1,
             padding: Union[int, List[int], Tuple[int, ...]] = 0,
             dilation: Union[int, List[int], Tuple[int, ...]] = 1,
             indice_key: Optional[str] = None,
             subm: bool = False,
             algo: Optional[ConvAlgo] = None,
             record_voxel_count: bool = False,
             name=None):
    """Configure a sparse average-pooling layer.

    Args:
        ndim: number of spatial dimensions.
        kernel_size: pooling window; a scalar is expanded to ``ndim`` entries.
        stride: window stride; ``None`` means "same as kernel_size".
        padding: per-dimension padding.
        dilation: per-dimension dilation.
        indice_key: optional key under which generated indice data is stored.
        subm: whether to run in submanifold mode.
        algo: accepted for signature parity with the other pool layers, but
            the layer always uses ``ConvAlgo.MaskImplicitGemm``.
        record_voxel_count: when True (and not ``subm``), register an int32
            buffer that tracks the maximum number of output voxels.
        name: forwarded to ``SparseModule``.

    Raises:
        AssertionError: if the kernel volume exceeds 32.
    """
    super(SparseAvgPool, self).__init__(name=name)
    self.ndim = ndim
    self.kernel_size = expand_nd(ndim, kernel_size)
    if stride is None:
        self.stride = self.kernel_size.copy()
    else:
        self.stride = expand_nd(ndim, stride)
    self.padding = expand_nd(ndim, padding)
    self.subm = subm
    if record_voxel_count and not self.subm:
        # we record maximum voxel num in both inference and training if
        # record_voxel_count flag setting.
        self.register_buffer(_MAX_NUM_VOXELS_DURING_TRAINING,
                             torch.zeros(1, dtype=torch.int32))
    self.record_voxel_count = record_voxel_count
    self.dilation = expand_nd(ndim, dilation)
    self.indice_key = indice_key
    # Use the expanded per-dimension kernel: a scalar kernel_size would
    # otherwise give kv == kernel_size instead of kernel_size ** ndim and slip
    # past the limit that the implicit-gemm indice generation enforces.
    kv = int(np.prod(self.kernel_size))
    assert kv <= 32, "avg pool only support implicit-gemm style indice gen with kv <= 32 limit"
    self.algo = ConvAlgo.MaskImplicitGemm
def extra_repr(self):
    """Build the attribute summary shown in the module's repr.

    ``kernel_size`` and ``stride`` are always listed. ``padding`` is listed
    unless it equals an all-zero tuple, ``dilation`` unless it equals an
    all-one tuple, and ``algo`` whenever it is not None.
    """
    pieces = ['kernel_size={kernel_size}', 'stride={stride}']
    if self.padding != (0, ) * len(self.padding):
        pieces.append('padding={padding}')
    if self.dilation != (1, ) * len(self.dilation):
        pieces.append('dilation={dilation}')
    if self.algo is not None:
        pieces.append(f'algo={self.algo}')
    template = ', '.join(pieces)
    return template.format(**self.__dict__)
def forward(self, input):
assert isinstance(input, spconv.SparseConvTensor)
features = input.features
device = features.device
indices = input.indices
spatial_shape = input.spatial_shape
batch_size = input.batch_size
if not self.subm:
out_spatial_shape = ops.get_conv_output_size(
spatial_shape, self.kernel_size, self.stride, self.padding,
self.dilation)
else:
out_spatial_shape = spatial_shape
out_tensor = input.shadow_copy()
out_padding = [0] * self.ndim
indice_dict = input.indice_dict.copy()
profile_ctx = nullcontext()
if input._timer is not None and self._sparse_unique_name:
profile_ctx = input._timer.namespace(self._sparse_unique_name)
with profile_ctx:
with input._timer.namespace("gen_pairs"):
res = ops.get_indice_pairs_implicit_gemm(
indices,
batch_size,
spatial_shape,
self.algo,
ksize=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
out_padding=out_padding,
subm=self.subm,
is_train=(not self.subm) or self.training,
alloc=input.thrust_allocator,
timer=input._timer)
outids = res[0]
num_inds_per_loc = res[1]
pair_fwd = res[2]
pair_bwd = res[3]
pair_mask_fwd_splits = res[4]
pair_mask_bwd_splits = res[5]
mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7]
masks = res[8]
if self.indice_key is not None:
indice_data = ImplicitGemmIndiceData(
outids,
indices,
pair_fwd,
pair_bwd,
pair_mask_fwd_splits=pair_mask_fwd_splits,
pair_mask_bwd_splits=pair_mask_bwd_splits,
mask_argsort_fwd_splits=mask_argsort_fwd_splits,
mask_argsort_bwd_splits=mask_argsort_bwd_splits,
masks=masks,
is_subm=self.subm,
spatial_shape=spatial_shape,
out_spatial_shape=out_spatial_shape,
algo=self.algo,
ksize=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data
out_features = Fsp.indice_avgpool_implicit_gemm(
features, pair_fwd, pair_bwd, outids.shape[0], self.training)
if not self.subm and self.record_voxel_count:
if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
ops.maximum_value_int_(
getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING),
outids.shape[0])
out_tensor = out_tensor.replace_feature(out_features)
out_tensor.indices = outids
out_tensor.indice_dict = indice_dict
......@@ -235,14 +373,17 @@ class SparseMaxPool1d(SparseMaxPool):
dilation=1,
indice_key=None,
algo: Optional[ConvAlgo] = None,
record_voxel_count: bool = False,
name=None):
super(SparseMaxPool1d, self).__init__(1,
super(SparseMaxPool1d,
self).__init__(1,
kernel_size,
stride,
padding,
dilation,
indice_key=indice_key,
algo=algo,
record_voxel_count=record_voxel_count,
name=name)
......@@ -254,14 +395,17 @@ class SparseMaxPool2d(SparseMaxPool):
dilation=1,
indice_key=None,
algo: Optional[ConvAlgo] = None,
record_voxel_count: bool = False,
name=None):
super(SparseMaxPool2d, self).__init__(2,
super(SparseMaxPool2d,
self).__init__(2,
kernel_size,
stride,
padding,
dilation,
indice_key=indice_key,
algo=algo,
record_voxel_count=record_voxel_count,
name=name)
......@@ -273,14 +417,17 @@ class SparseMaxPool3d(SparseMaxPool):
dilation=1,
indice_key=None,
algo: Optional[ConvAlgo] = None,
record_voxel_count: bool = False,
name=None):
super(SparseMaxPool3d, self).__init__(3,
super(SparseMaxPool3d,
self).__init__(3,
kernel_size,
stride,
padding,
dilation,
indice_key=indice_key,
algo=algo,
record_voxel_count=record_voxel_count,
name=name)
......@@ -292,12 +439,87 @@ class SparseMaxPool4d(SparseMaxPool):
dilation=1,
indice_key=None,
algo: Optional[ConvAlgo] = None,
record_voxel_count: bool = False,
name=None):
super(SparseMaxPool4d,
self).__init__(4,
kernel_size,
stride,
padding,
dilation,
indice_key=indice_key,
algo=algo,
record_voxel_count=record_voxel_count,
name=name)
class SparseAvgPool1d(SparseAvgPool):
    """One-dimensional ``SparseAvgPool`` (``ndim`` fixed to 1).

    avg pool that use real point count instead of kernel size.
    """

    def __init__(self,
                 kernel_size,
                 stride=None,
                 padding=0,
                 dilation=1,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 record_voxel_count: bool = False,
                 name=None):
        # Everything except the dimensionality is forwarded unchanged.
        super().__init__(ndim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         dilation=dilation,
                         indice_key=indice_key,
                         algo=algo,
                         record_voxel_count=record_voxel_count,
                         name=name)
class SparseAvgPool2d(SparseAvgPool):
    """Two-dimensional ``SparseAvgPool`` (``ndim`` fixed to 2).

    avg pool that use real point count instead of kernel size.
    """

    def __init__(self,
                 kernel_size,
                 stride=None,
                 padding=0,
                 dilation=1,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 record_voxel_count: bool = False,
                 name=None):
        # Everything except the dimensionality is forwarded unchanged.
        super().__init__(ndim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         dilation=dilation,
                         indice_key=indice_key,
                         algo=algo,
                         record_voxel_count=record_voxel_count,
                         name=name)
class SparseAvgPool3d(SparseAvgPool):
    """Three-dimensional ``SparseAvgPool`` (``ndim`` fixed to 3).

    avg pool that use real point count instead of kernel size.
    """

    def __init__(self,
                 kernel_size,
                 stride=None,
                 padding=0,
                 dilation=1,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 record_voxel_count: bool = False,
                 name=None):
        # Only the SparseAvgPool3d initialiser call belongs here; the stray
        # SparseMaxPool4d call header that preceded it has been dropped.
        super(SparseAvgPool3d,
              self).__init__(3,
                             kernel_size,
                             stride,
                             padding,
                             dilation,
                             indice_key=indice_key,
                             algo=algo,
                             record_voxel_count=record_voxel_count,
                             name=name)
# Build one Catch2 test executable per src/test_*.cpp file and register each
# of them with CTest.

# Location of the vendored Catch2 headers.
set(CATCH_HEADER ${PROJECT_SOURCE_DIR}/third_party/catch2)
# Compile the shared Catch2 entry point once as an object library; its object
# files are added to every test executable below.
add_library(catch_main OBJECT src/catch_main.cpp)
# target_compile_features(catch_main PUBLIC cxx_std_2a)
set_property(TARGET catch_main PROPERTY CXX_STANDARD 14)
target_include_directories(catch_main PRIVATE ${CATCH_HEADER})
# Collect every test source file.
file(GLOB files "src/test_*.cpp")
foreach(file ${files})
# File name without directory or extension, e.g. test_foo.
get_filename_component(file_basename ${file} NAME_WE)
# Derive the target/test name: test_foo -> test-foo.
string(REGEX REPLACE "test_([^$]+)" "test-\\1" testcase ${file_basename})
add_executable(${testcase} ${file} $<TARGET_OBJECTS:catch_main>)
set_property(TARGET ${testcase} PROPERTY CXX_STANDARD 14)
# set_target_properties(${testcase} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
# set_property(TARGET ${testcase} PROPERTY CUDA_STANDARD 14)
target_compile_definitions(${testcase} PRIVATE
CATCH_CONFIG_FAST_COMPILE
)
target_include_directories(${testcase} PRIVATE
${CATCH_HEADER} ${ALL_INCLUDE}
)
# NOTE(review): -Wl,--no-as-needed presumably keeps spconv linked even when
# the test references none of its symbols directly -- confirm it is required.
target_link_libraries(${testcase} ${ALL_LIBS} pybind11::embed -Wl,--no-as-needed spconv)
# Tests run with the top-level source directory as their working directory.
add_test(NAME "${testcase}"
COMMAND ${testcase}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR})
endforeach()
\ No newline at end of file
......@@ -113,7 +113,7 @@ class Net(nn.Module):
# nn.BatchNorm1d(32),
# nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo, record_voxel_count=True),
spconv.SubMConv3d(64,
96,
3,
......@@ -332,7 +332,7 @@ def main():
voxels_th = torch.from_numpy(voxels).to(device).to(dtype)
coors_th = torch.from_numpy(coors).to(device).int()
voxels_th.requires_grad = True
algo = spconv.ConvAlgo.MaskImplicitGemm
algo = spconv.ConvAlgo.Native
# 3080 Laptop
# MaskImpGemm: 11.2ms
# MaskSplitImpGemm: 12.2ms
......@@ -385,21 +385,25 @@ def main():
torch.cuda.synchronize()
# sort_bench()
times.append(time.time() - t)
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
# net.load_state_dict(state)
# breakpoint()
print("spconv time", np.mean(times[10:]))
times = []
for i in range(10):
out = net(voxels_th, coors_th, 1)
print("------------")
torch.cuda.synchronize()
t = time.time()
out.features.backward(dout_t)
torch.cuda.synchronize()
times.append(time.time() - t)
# # print((net.grid == -1).float().sum(), net.grid.numel())
# # print("spconv time", time.time() - t)
print("spconv bw time", np.mean(times[5:]))
# times = []
# for i in range(10):
# out = net(voxels_th, coors_th, 1)
# print("------------")
# torch.cuda.synchronize()
# t = time.time()
# out.features.backward(dout_t)
# torch.cuda.synchronize()
# times.append(time.time() - t)
# # # print((net.grid == -1).float().sum(), net.grid.numel())
# # # print("spconv time", time.time() - t)
# print("spconv bw time", np.mean(times[5:]))
if __name__ == "__main__":
......
......@@ -248,7 +248,7 @@ def test_spconv3d():
ConvAlgo.Native, ConvAlgo.MaskImplicitGemm,
ConvAlgo.MaskSplitImplicitGemm
]
algos = [ConvAlgo.Native]
algos = [ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, ConvAlgo.MaskSplitImplicitGemm]
for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment