import sys
from pathlib import Path
from typing import Dict, List, Tuple

import pickle
import sys
import time
from pathlib import Path
from cumm.gemm.algospec.core import GemmAlgo

import numpy as np
import pccm
import torch
import torch.nn.functional as F
from cumm import dtypes
from cumm import tensorview as tv
from cumm.constants import PACKAGE_ROOT
from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType
from cumm.conv.main import ConvMainUnitTest, gen_gemm_kernels
from cumm.conv.params import ConvProblem
from cumm.gemm import kernel
import os
from spconv.core_cc.csrc.sparse.all import SpconvOps
from cumm.gemm.codeops import div_up
# NOTE: this rebinds PACKAGE_ROOT; the spconv value is the one used below.
from spconv.constants import PACKAGE_ROOT
from spconv.core import ConvAlgo
from spconv.pytorch import ops
from spconv.algo import CONV, BestConvAlgoByProfile
from spconv.pytorch.cppcore import torch_tensor_to_tv


def reduce_mask_count(mask: np.ndarray, width: int):
    """Count the bits left after OR-reducing ``mask`` in groups of ``width``.

    The mask words are OR-reduced group-wise (see ``reduce_mask_count_x``),
    the set bits of every reduced word are counted with
    ``SpconvOps.count_bits`` and the total is multiplied by ``width``.

    Args:
        mask: 1-D array of bit-mask words.
        width: number of consecutive words merged into one group.

    Returns:
        Total bit count of the reduced words times ``width``.
    """
    maskr = reduce_mask_count_x(mask, width)
    maskr_tv = tv.from_numpy(maskr)
    return SpconvOps.count_bits(maskr_tv).numpy().sum() * width


def reduce_mask_count_x(mask: np.ndarray, width: int):
    """OR-reduce ``mask`` over consecutive groups of ``width`` words.

    When the length of ``mask`` is smaller than ``div_up(len, width) * width``
    it is zero-padded to that length first, so the final group is complete.

    Args:
        mask: 1-D array of bit-mask words.
        width: number of consecutive words merged into one group.

    Returns:
        1-D array holding one OR-reduced word per group.
    """
    padded_length = div_up(mask.shape[0], width) * width
    if mask.shape[0] < padded_length:
        # Zero padding does not change the OR of the last group.
        padded = np.zeros((padded_length, ), dtype=mask.dtype)
        padded[:mask.shape[0]] = mask
        mask = padded
    return np.bitwise_or.reduce(mask.reshape(-1, width), axis=1)


def _active_pair_count(filter_offset, kv, subm, indice_num_per_loc_np,
                       num_voxels):
    """Return how many leading pairs of the native pair table to use."""
    if subm and filter_offset > kv // 2:
        # For subm the count is read from the mirrored kernel offset.
        return indice_num_per_loc_np[kv - 1 - filter_offset]
    if subm and filter_offset == kv // 2:
        # Centre offset of a subm kernel: one pair per input voxel.
        return num_voxels
    return indice_num_per_loc_np[filter_offset]


def _forward_reference(inp, output, weight_ref, indice_pairs_np, pair_counts):
    """Compute the forward result with numpy gather / matmul / scatter-add."""
    output_ref = np.zeros_like(output, dtype=np.float32)
    for filter_offset, nhot in enumerate(pair_counts):
        a_inds = indice_pairs_np[0][filter_offset][:nhot]
        c_inds = indice_pairs_np[1][filter_offset][:nhot]
        # [N, C] @ [C, K]
        cc = inp[a_inds].astype(
            np.float32) @ weight_ref[filter_offset].T.astype(np.float32)
        output_ref[c_inds] += cc
    return output_ref


def _backward_input_reference(inp, output, weight_ref, indice_pairs_np,
                              pair_counts):
    """Compute the input gradient with numpy gather / matmul / scatter-add."""
    dinput_ref = np.zeros_like(inp, dtype=np.float32)
    for filter_offset, nhot in enumerate(pair_counts):
        a_inds = indice_pairs_np[1][filter_offset][:nhot]
        c_inds = indice_pairs_np[0][filter_offset][:nhot]
        # [N, K] @ [K, C]
        cc = output[a_inds].astype(
            np.float32) @ weight_ref[filter_offset].astype(np.float32)
        dinput_ref[c_inds] += cc
    return dinput_ref


def _backward_weight_reference(inp, output, weight_ref, indice_pairs_np,
                               pair_counts):
    """Compute the weight gradient with numpy, returned as [K, KV, C]."""
    dw_ref = np.zeros_like(weight_ref, dtype=np.float32)  # [KV, K, C]
    for filter_offset, nhot in enumerate(pair_counts):
        o_inds = indice_pairs_np[1][filter_offset][:nhot]
        i_inds = indice_pairs_np[0][filter_offset][:nhot]
        out_gather = output[o_inds]  # [N, K]
        inp_gather = inp[i_inds]  # [N, C]
        # [K, N] @ [N, C]
        dw_ref[filter_offset] = out_gather.astype(
            np.float32).T @ inp_gather.astype(np.float32)
    return dw_ref.transpose(1, 0, 2)


def dev_subm_inds_v2(subm: bool = True,
                     run_conv: bool = True,
                     bench_only: bool = True):
    """Development driver for implicit-gemm indice pairs and mask reduction.

    Loads the pickled test data, builds native and implicit-gemm indice
    pairs, prints per-location pair counts and benchmarks the mask reduction
    on the first forward mask split. Unless ``bench_only`` is disabled the
    function stops there. Otherwise it prints mask statistics and, when
    ``run_conv`` is set, runs every SIMT kernel in ``CONV.desps`` and prints
    the L2 error against a numpy reference.

    Requires a CUDA device and the ``test/data/test_spconv.pkl`` file next to
    the package root.

    Args:
        subm: use the 3x3x3 / padding 1 setup with ``subm=True`` indice
            pairs; otherwise a 2x2x2 / padding 0 regular convolution.
        run_conv: run and verify the convolution kernels. Only consulted
            when ``bench_only`` is False.
        bench_only: return right after the mask-reduction benchmark. The
            default keeps the historical behaviour of this script, which
            always returned at that point.
    """
    # Set to an int to truncate the input; None keeps every voxel.
    limit_input_n = None
    np.random.seed(484)
    # NOTE: pickle executes arbitrary code on load; this file is expected to
    # be a trusted fixture from the repository.
    with (PACKAGE_ROOT.parent / "test/data/test_spconv.pkl").open("rb") as f:
        voxels_np, indices_np, spatial_shape = pickle.load(f)
    # Synthetic inputs can be built with
    # spconv.test_utils.generate_sparse_data instead of the pickled data.
    voxels_np = voxels_np[:limit_input_n]
    indices_np = indices_np[:limit_input_n]
    indices_th = torch.from_numpy(indices_np).cuda()
    print(spatial_shape, indices_np.shape)

    ndim = 3
    if subm:
        ksize = [3, 3, 3]
        padding = [1] * ndim
    else:
        ksize = [2, 2, 2]
        padding = [0] * ndim
    kv = np.prod(ksize)
    stride = [1] * ndim
    dilation = [1] * ndim
    out_padding = [0] * ndim

    # Native indice pairs drive the numpy reference computation below.
    _, pair_ref, indice_num_per_loc = ops.get_indice_pairs(
        indices_th, 1, spatial_shape, ConvAlgo.Native, ksize, stride, padding,
        dilation, out_padding, subm)
    indice_num_per_loc_np = indice_num_per_loc.cpu().numpy()
    indice_pairs_np = pair_ref.cpu().numpy()

    algo = ConvAlgo.MaskImplicitGemm
    num_split = 1 if algo == ConvAlgo.MaskImplicitGemm else 2
    # The call is repeated as in the original script; only the last result
    # is used.
    for _ in range(5):
        res = ops.get_indice_pairs_implicit_gemm(indices_th, 1, spatial_shape,
                                                 algo, ksize, stride, padding,
                                                 dilation, out_padding, subm)
    # res[1] is not used by this script.
    out_inds = res[0]
    pair_fwd = res[2]
    pair_bwd = res[3]
    pair_mask_fwd_splits = res[4]
    pair_mask_bwd_splits = res[5]

    # Count the pairs per kernel offset. Entries equal to -1 are treated as
    # "no pair"; index 0 is a regular index and must be counted (the previous
    # `> 0` test after mapping -1 to 0 dropped it).
    pair_fwd_np = pair_fwd.cpu().numpy().reshape(-1)
    loc_num_np = (pair_fwd_np >= 0).reshape(kv, -1).sum(1)
    print(loc_num_np)
    print(indice_num_per_loc_np)

    mask_np = torch_tensor_to_tv(pair_mask_fwd_splits[0],
                                 dtype=tv.uint32).cpu().numpy()
    bench_reduce_mask(mask_np)
    if bench_only:
        return

    mask_argsort_fwd_splits = res[6]
    mask_argsort_bwd_splits = res[7]
    masks = res[8]

    pair_mask_fwd_splits_tv = [
        torch_tensor_to_tv(t, dtype=tv.uint32) for t in pair_mask_fwd_splits
    ]
    valid_location_count = sum(
        SpconvOps.count_bits(t).cpu().numpy().sum()
        for t in pair_mask_fwd_splits_tv)
    reduce_length = 32
    split_mask_valid_count = sum(
        reduce_mask_count(t.cpu().numpy(), reduce_length)
        for t in pair_mask_fwd_splits_tv)
    print("SUBM" if subm else "REGULAR", valid_location_count,
          split_mask_valid_count, pair_fwd.numel())

    if not run_conv:
        return

    C = 64
    K = 64
    mask_output_fwd = torch.zeros([2, div_up(out_inds.shape[0], 32)],
                                  dtype=torch.int32,
                                  device=indices_th.device)
    mask_output_bwd = torch.zeros([2, div_up(indices_np.shape[0], 32)],
                                  dtype=torch.int32,
                                  device=indices_th.device)
    # Number of native pairs to use per kernel offset; shared by every
    # reference computation.
    pair_counts = [
        _active_pair_count(offset, kv, subm, indice_num_per_loc_np,
                           voxels_np.shape[0]) for offset in range(kv)
    ]

    for desp in CONV.desps:
        if desp.algo != GemmAlgo.Simt.value:
            continue
        is_forward = desp.op_type == ConvOpType.kForward.value
        is_bwd_input = desp.op_type == ConvOpType.kBackwardInput.value
        is_bwd_weight = desp.op_type == ConvOpType.kBackwardWeight.value

        if desp.dtype_a == dtypes.int8.tv_dtype:
            inp = np.random.randint(-1, 1, size=[voxels_np.shape[0],
                                                 C]).astype(np.int8)
            weight = np.random.randint(-1, 1, size=[K, *ksize,
                                                    C]).astype(np.int8)
            output = np.random.randint(-1, 1, size=[
                out_inds.shape[0], K
            ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_output))
        else:
            inp = np.random.uniform(-1, 1, size=[
                voxels_np.shape[0], C
            ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_input))
            weight = np.random.uniform(-1, 1, size=[K, *ksize, C]).astype(
                dtypes.get_npdtype_from_tvdtype(desp.dtype_weight))
            output = np.random.uniform(-1, 1, size=[
                out_inds.shape[0], K
            ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_output))
        # [K, *ksize, C] -> [KV, K, C] for the per-offset reference matmuls.
        weight_ref = weight.transpose(1, 2, 3, 0, 4)
        weight_ref = np.ascontiguousarray(weight_ref).reshape(-1, K, C)

        # The tensor written by the kernel starts zeroed; the others hold the
        # random data.
        if is_bwd_input:
            inp_tv = tv.zeros(inp.shape, desp.dtype_input, 0)
        else:
            inp_tv = tv.from_numpy(inp).cuda()
        if is_bwd_weight:
            weight_tv = tv.zeros(weight.shape, desp.dtype_weight, 0)
        else:
            weight_tv = tv.from_numpy(weight).cuda()
        if is_forward:
            output_tv = tv.zeros(output.shape, desp.dtype_output, 0)
        else:
            output_tv = tv.from_numpy(output).cuda()
        torch.cuda.synchronize()

        spk = 1
        if is_bwd_weight:
            # TODO support splitk parallel
            spk = 32

        indice_pairs = pair_bwd if is_bwd_input else pair_fwd
        if is_bwd_input and not subm:
            # out -> inp
            mask_ops = pair_mask_bwd_splits
            mask_argsorts = mask_argsort_bwd_splits
            mask_output = mask_output_bwd
            print([bin(x.item()) for x in masks])
        else:
            # inp -> out (subm reuses the forward masks for every op type)
            mask_ops = pair_mask_fwd_splits
            mask_argsorts = mask_argsort_fwd_splits
            mask_output = mask_output_fwd
        # Only subm backward-input runs the forward masks reversed.
        reverse_mask = bool(subm and is_bwd_input)

        # Iterate over the splits that actually exist; the regular-conv path
        # previously hard-coded two splits although num_split is 1 here.
        for j in range(num_split):
            beta = 1 if j == 1 else 0
            mask_filter = masks[j].item()
            mask_op = mask_output[j] if is_bwd_weight else mask_ops[j]
            CONV.run_with_tuned_result(
                BestConvAlgoByProfile(desp, spk),
                desp.op_type,
                inp_tv,
                weight_tv,
                output_tv,
                torch_tensor_to_tv(mask_op, dtype=tv.uint32),
                torch_tensor_to_tv(mask_argsorts[j]),
                torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
                torch_tensor_to_tv(indice_pairs),
                reverse_mask,
                mask_filter=mask_filter,
                mask_width=32,
                beta=beta,
                verbose=True,
            )
        torch.cuda.synchronize()

        if is_forward:
            output_ref = _forward_reference(inp, output, weight_ref,
                                            indice_pairs_np, pair_counts)
            output_cpu = output_tv.cpu().numpy().astype(np.float32)
            print("ERROR",
                  np.linalg.norm(output_ref.reshape(-1) - output_cpu.reshape(-1)))
        elif is_bwd_input:
            dinput_ref = _backward_input_reference(inp, output, weight_ref,
                                                   indice_pairs_np,
                                                   pair_counts)
            din_cpu = inp_tv.cpu().numpy()
            print(
                "ERROR",
                np.linalg.norm(din_cpu.reshape(-1) - dinput_ref.reshape(-1)))
        else:
            dw_ref_kcrs = _backward_weight_reference(inp, output, weight_ref,
                                                     indice_pairs_np,
                                                     pair_counts)
            dw_cpu = weight_tv.cpu().numpy().reshape(K, np.prod(ksize), C)
            print(
                "ERROR",
                np.linalg.norm(dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1)))


def reverse_bits(a: np.ndarray):
    """Reverse the bit order inside every byte of ``a``.

    The bytes are unpacked least-significant-bit first and packed again with
    numpy's default (most-significant-bit first) order.

    Args:
        a: ``uint8`` array (``np.unpackbits`` rejects other dtypes).

    Returns:
        Flat ``uint8`` array with the bits of each byte reversed.
    """
    a_unpack = np.unpackbits(a, bitorder="little")
    return np.packbits(a_unpack)


def _count_mask_reduce(masks: np.ndarray):
    """Print raw bit count, 64-wide reduced count and their ratio."""
    masks_tv_count = SpconvOps.count_bits(tv.from_numpy(masks))
    masks_tv_count_sum = masks_tv_count.numpy_view().sum()
    reduce_count = reduce_mask_count(masks, 64)
    # With no bit set the ratio is undefined; report nan without dividing
    # by zero.
    if masks_tv_count_sum:
        ratio = reduce_count / masks_tv_count_sum
    else:
        ratio = float("nan")
    print(masks_tv_count_sum, reduce_count, ratio)


def bench_reduce_mask(masks: np.ndarray, width: int = 27):
    """Print mask-reduction statistics for several orderings of ``masks``.

    Four variants are reported through ``_count_mask_reduce``: the sorted
    masks, the sorted masks restricted to their low ``width // 2 + 1`` bits,
    and the same two measurements on the masks concatenated with their
    bit-reversed copies (``SpconvOps.reverse_bits``).

    Args:
        masks: 1-D ``uint32`` array of mask words; it is not modified.
        width: number of meaningful bits per mask word.
    """
    # Keeps the low (width // 2 + 1) bits of every word.
    width_half_mask = np.array(0xffffffff,
                               dtype=np.uint32) >> (32 - width // 2 - 1)
    print(bin(width_half_mask))

    masks_sort = masks.copy()
    masks_sort.sort()
    _count_mask_reduce(masks_sort)

    masks_sort = masks.copy() & width_half_mask
    masks_sort.sort()
    _count_mask_reduce(masks_sort)

    reversed_masks = SpconvOps.reverse_bits(tv.from_numpy(masks)).numpy()
    new_masks = np.concatenate([masks, reversed_masks])
    # The shuffle cannot change the sorted result; it is kept so the global
    # numpy RNG stream consumed by this benchmark stays the same.
    np.random.shuffle(new_masks)
    new_masks.sort()
    _count_mask_reduce(new_masks)

    new_masks &= width_half_mask
    new_masks.sort()
    _count_mask_reduce(new_masks)


if __name__ == "__main__":
    dev_subm_inds_v2()