Commit 899008fa authored by yan.yan's avatar yan.yan
Browse files

working on c++ only

parent f78575ea
......@@ -248,8 +248,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
class SparseConvIndicesKernel(pccm.ParameterizedClass):
def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType):
super().__init__()
self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel,
ThrustLib)
self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel)
self.loc_iter = ConvOutLocIter(problem)
self.add_param_class("spinds", self.loc_iter, "ConvLocIter")
self.add_param_class("spinds", problem, "ConvProblem")
......@@ -271,7 +270,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("indice_pairs",
f"{self.dtype_indices}*") # [2, kernelProd, MaxSize]
code.arg("indice_pairs_for_uniq",
f"TIndiceUniq*") # [2, kernelProd, MaxSize]
f"TIndiceUniq*") # [kernelProd * MaxSize + 1]
code.arg("indice_num_per_loc", f"int*") # [kernelProd]
code.arg("num_indices_in", "int")
......@@ -340,7 +339,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("indice_pairs_out_part", f"int*") # [kernelProd, MaxSize]
code.arg("num_indices_in", "int")
code.arg("indices_pair_size", "int")
# TODO use block instead of filter_offset?
code.raw(f"""
int filter_offset = blockIdx.y;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
......@@ -358,6 +356,46 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""")
return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_bounded(self):
"""if we bound output indices, some pair may be invalid,
so we need to atomicAdd and assign again.
here we will use indice_pairs_uniq as temp memory of
indice_pairs_in_part.
"""
code = pccm.FunctionCode()
code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_uniq_before_sort", f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("indice_pairs_in_part_temp", f"const int*") # [kernelProd, MaxSize]
code.arg("indice_pairs_in_part", f"int*") # [kernelProd, MaxSize]
code.arg("indice_pairs_out_part", f"int*") # [kernelProd, MaxSize]
code.arg("indice_num_per_loc", f"int*") # [kernelProd]
code.arg("num_indices_in", "int")
code.arg("indices_pair_size", "int")
code.raw(f"""
int filter_offset = blockIdx.y;
auto indice_pairs_in_part_filter = indice_pairs_in_part + filter_offset * indices_pair_size;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
auto indice_pairs_in_part_temp_filter = indice_pairs_in_part_temp + filter_offset * indices_pair_size;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * indices_pair_size;
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
{self.dtype_indices} output_coord_offset = indice_pairs_uniq_before_sort_filter[i];
if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
auto table_offset = table.lookup_offset(output_coord_offset);
if (table_offset != -1){{
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
indice_pairs_in_part_filter[old_num] = indice_pairs_in_part_temp_filter[i];
indice_pairs_out_part_filter[old_num] = table.value_ptr()[table_offset];
}}
}}
}}
""")
return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask(self):
code = pccm.FunctionCode()
......@@ -369,7 +407,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("indice_pairs_bwd",
f"{self.dtype_indices}*") # [kernelProd, MaxSize]
code.arg("indice_pairs_for_uniq",
f"TIndiceUniq*") # [2, kernelProd, MaxSize]
f"TIndiceUniq*") # [kernelProd * MaxSize + 1]
code.arg("indice_num_per_loc", f"int*") # [kernelProd]
code.arg("num_indices_in", "int")
......@@ -397,6 +435,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// }}
}}
......@@ -420,7 +459,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices_in", "int")
code.arg("num_indices_out", "int")
# TODO use block instead of filter_offset?
code.raw(f"""
int filter_offset = blockIdx.y;
uint32_t filter_mask_fwd = (1u << (filter_offset));
......@@ -458,7 +496,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices_in", "int")
code.arg("kv", "int")
# TODO use block instead of filter_offset?
code.raw(f"""
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
uint32_t mask = 0;
......@@ -749,18 +786,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
indice_pairs.dim(2), kv, transposed);
}});
// thrust::device_ptr<{self.dtype_indices}> ptr_tr(indice_pairs_uniq.data_ptr<{self.dtype_indices}>());
// auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int));
// thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
// auto new_end = thrust::unique(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
// auto num_out_act = new_end - ptr_tr - 1;
// return num_out_act;
""")
return code # .ret("int")
@pccm.cuda.static_function
def generate_conv_inds_stage1_5(self):
code = pccm.FunctionCode()
code.add_dependency(ThrustLib)
code.arg("indice_pairs_uniq", "tv::Tensor")
code.arg("uniq_size", "int64_t")
code.arg("stream_int", f"std::uintptr_t", "0")
......@@ -783,6 +815,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code = pccm.FunctionCode()
code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds", "tv::Tensor")
code.arg("indice_num_per_loc", "tv::Tensor")
code.arg("num_out_act", "int")
code.arg("batch_size", "int")
code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>")
......@@ -790,6 +824,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("use_bound_algo", "bool", "false")
code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream
......@@ -798,6 +834,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
auto ctx = tv::Context();
ctx.set_cuda_stream(custream);
// indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
......@@ -805,6 +843,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// auto timer = tv::CudaContextTimer<>();
int64_t uniq_size = indice_pairs.size() / 2 + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= num_out_act, "error");
// int num_out_act_bounded = num_out_act;
// if (num_out_act_bound > 0){{
// num_out_act_bounded = std::min(num_out_act_bounded, num_out_act);
// }}
TV_ASSERT_RT_ERR(out_inds.dim(0) >= num_out_act && out_inds.dim(1) == {self.ndim + 1}, "error");
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
launcher_num_act_in.blocks.y = kv;
......@@ -827,11 +869,29 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act);
launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash,
indice_pairs_uniq_before_sort.data_ptr<const K>(),
indice_pairs[1].data_ptr<int>(),
indices.dim(0),
indice_pairs.dim(2));
if (!use_bound_algo){{
launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash,
indice_pairs_uniq_before_sort.data_ptr<const K>(),
indice_pairs[1].data_ptr<int>(),
indices.dim(0),
indice_pairs.dim(2));
}}else{{
indice_num_per_loc.zero_(ctx);
// copy previous pair in to indice_pairs_uniq
// we need to ensure size of indice_pairs_uniq larger than pair in
TV_ASSERT_RT_ERR({pccm.literal(self.dtype_indices == dtypes.int32)}, "error");
tv::Tensor indice_pairs_in_temp = tv::from_blob(indice_pairs_uniq.raw_data(), {{indice_pairs.dim(1), indice_pairs.dim(2)}},
indice_pairs.dtype(), indice_pairs.device());
indice_pairs_in_temp.copy_(indice_pairs[0].view(-1), ctx);
launcher_num_act_in(calc_conv_indices_stage2_bounded<table_t>, hash,
indice_pairs_uniq_before_sort.data_ptr<const K>(),
indice_pairs_in_temp.data_ptr<const int>(),
indice_pairs[0].data_ptr<int>(),
indice_pairs[1].data_ptr<int>(),
indice_num_per_loc.data_ptr<int>(),
indices.dim(0),
indice_pairs.dim(2));
}}
}});
return num_out_act;
""")
......@@ -899,6 +959,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0")
code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream
......
import os
import fire
from cumm.common import CompileInfo
from cumm.conv.main import ConvMainUnitTest
from cumm.gemm.main import GemmMainUnitTest
from pccm.builder.pybind import gen_cmake
from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS,
IMPLGEMM_VOLTA_PARAMS, SHUFFLE_SIMT_PARAMS,
SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS)
from spconv.csrc.hash.core import HashTable
from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.sparse.alloc import ExternalAllocator
from spconv.csrc.sparse.convops import (ConvGemmOps, ConvTunerSimple,
ExternalSpconvMatmul, GemmTunerSimple,
SimpleExternalSpconvMatmul)
from spconv.csrc.utils import BoxOps
def main(include: str,
src: str,
libname: str = "spconv",
prefix: str = "spconvlib"):
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS)
# all_imp = IMPLGEMM_SIMT_PARAMS
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main"
gemmtuner = GemmTunerSimple(cu)
gemmtuner.namespace = "csrc.sparse.convops.gemmops"
convtuner = ConvTunerSimple(convcu)
convtuner.namespace = "csrc.sparse.convops.convops"
convops = ConvGemmOps(gemmtuner, convtuner)
convops.namespace = "csrc.sparse.convops.spops"
cus = [
cu,
convcu,
gemmtuner,
convtuner,
convops,
SpconvOps(),
BoxOps(),
HashTable(),
CompileInfo(),
ExternalAllocator(),
ExternalSpconvMatmul(),
SimpleExternalSpconvMatmul(),
]
gen_cmake(libname, cus, include, src, namespace_prefix=prefix)
if __name__ == "__main__":
fire.Fire(main)
......@@ -38,20 +38,6 @@ from torch.nn.init import calculate_gain
FILTER_HWIO = False
def expand_nd(val: Union[int, List[int], Tuple[int, ...]], ndim: int) -> List[int]:
if isinstance(val, int):
val = [val] * ndim
elif isinstance(val, list):
assert len(val) == ndim
elif isinstance(val, tuple):
assert len(val) == ndim
return [*val]
else:
raise NotImplementedError
return val
class SparseConvolution(SparseModule):
__constants__ = [
'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
......@@ -82,6 +68,7 @@ class SparseConvolution(SparseModule):
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = expand_nd(ndim, kernel_size)
self.stride = expand_nd(ndim, stride)
kv = int(np.prod(self.kernel_size))
kv_stride = int(np.prod(self.stride))
......@@ -130,7 +117,6 @@ class SparseConvolution(SparseModule):
# KRSC
self.weight = Parameter(
torch.Tensor(out_channels, *self.kernel_size, in_channels))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
......
......@@ -15,9 +15,13 @@
from cumm import tensorview as tv
import torch
from typing import Dict, Optional, List, Union
from spconv.constants import AllocKeys
from spconv.cppconstants import COMPILED_CUDA_ARCHS
import sys
import sys
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.core_cc.csrc.sparse.convops import ExternalSpconvMatmul
import numpy as np
_TORCH_DTYPE_TO_TV = {
torch.float32: tv.float32,
......@@ -31,8 +35,16 @@ _TORCH_DTYPE_TO_TV = {
}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TORCH_UINT_WORKAROUNDS = {tv.uint32: tv.int32, tv.uint16: tv.int16, tv.uint64: tv.int64}
_ALL_INTS = {tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32, tv.uint16}
_TORCH_UINT_WORKAROUNDS = {
tv.uint32: tv.int32,
tv.uint16: tv.int16,
tv.uint64: tv.int64
}
_ALL_INTS = {
tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
tv.uint16
}
def torch_tensor_to_tv(ten: torch.Tensor,
dtype: Optional[int] = None,
......@@ -62,6 +74,7 @@ def torch_tensor_to_tv(ten: torch.Tensor,
return tv.from_blob(ptr, shape, dtype, tv_device)
return tv.from_blob_strided(ptr, shape, stride, dtype, tv_device)
def torch_tensors_to_tv(*tens: torch.Tensor):
return (torch_tensor_to_tv(t) for t in tens)
......@@ -69,28 +82,35 @@ def torch_tensors_to_tv(*tens: torch.Tensor):
def get_current_stream():
return torch.cuda.current_stream().cuda_stream
def get_arch():
arch = torch.cuda.get_device_capability()
if arch not in COMPILED_CUDA_ARCHS:
print(f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
f"may cause invalid device function. "
f"available: {COMPILED_CUDA_ARCHS}", file=sys.stderr)
print(
f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
f"may cause invalid device function. "
f"available: {COMPILED_CUDA_ARCHS}",
file=sys.stderr)
return arch
class TorchAllocator(ExternalAllocator):
def __init__(self, gpudevice: torch.device) -> None:
super().__init__()
self.gpudevice = gpudevice
self.cpudevice = torch.device("cpu:0")
self.cpudevice = torch.device("cpu")
self.allocated: Dict[Union[str, int], torch.Tensor] = {}
def zeros(self, name: str, shape: List[int], dtype: int, device: int) -> tv.Tensor:
def zeros(self, name: str, shape: List[int], dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor:
# TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit.
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
......@@ -99,18 +119,19 @@ class TorchAllocator(ExternalAllocator):
ten = torch.zeros(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
if name:
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def empty(self, name: str, shape: List[int], dtype: int, device: int) -> tv.Tensor:
def empty(self, name: str, shape: List[int], dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor:
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
......@@ -119,20 +140,21 @@ class TorchAllocator(ExternalAllocator):
ten = torch.empty(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
if name:
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int) -> tv.Tensor:
def full_int(self, name: str, shape: List[int], value: int, dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
......@@ -142,22 +164,21 @@ class TorchAllocator(ExternalAllocator):
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
if name:
self.allocated[name] = ten
if name:
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int) -> tv.Tensor:
def full_float(self, name: str, shape: List[int], value: float, dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
......@@ -166,12 +187,15 @@ class TorchAllocator(ExternalAllocator):
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
if name:
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def get_tensor_by_name(self, name: str):
return torch_tensor_to_tv(self.allocated[name])
def free(self, ten: tv.Tensor):
if ten.storage_bytesize() != ten.bytesize():
raise ValueError("you can't free a sliced tensor.")
......@@ -189,6 +213,130 @@ class TorchAllocator(ExternalAllocator):
return
class TorchSpconvMatmul(ExternalSpconvMatmul):
def __init__(self, alloc: TorchAllocator) -> None:
super().__init__()
self.alloc = alloc
def indice_conv_init_gemm(self, features_n: str, filters_n: str,
all_weight_is_krsc: bool, is_kc_not_ck: bool,
kv_center: int, out_channel: int, stream_int: int = 0):
features = self.alloc.allocated[features_n]
filters = self.alloc.allocated[filters_n]
if not all_weight_is_krsc:
filters = filters.reshape(-1, *filters.shape[-2:])
if not is_kc_not_ck:
out_features = torch.mm(features, filters[kv_center])
else:
out_features = torch.mm(features, filters[kv_center].T)
else:
filters = filters.reshape(out_channel, -1, filters.shape[-1])
if features.is_cuda or (features.dtype != torch.float16):
out_features = torch.mm(features, filters[:, kv_center].T)
else:
# pytorch 1.12 don't support cpu half mm, f**k pytorch
# we need cpu fp16 mm for test only.
out_features = torch.empty((features.shape[0], out_channel),
dtype=features.dtype,
device=features.device)
features_np = torch_tensor_to_tv(features).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
out_features_np = torch_tensor_to_tv(out_features).numpy_view()
np.matmul(features_np,
filters_np[:, kv_center].T,
out=out_features_np)
self.alloc.allocated[AllocKeys.OutFeatures] = out_features
# print(filters.shape, features.shape, all_weight_is_krsc, out_features.shape, out_features.is_contiguous())
return torch_tensor_to_tv(out_features)
def indice_conv_cpu_gemm(self, inp_buffer_n: str, out_buffer_n: str, filters_n: str,
all_weight_is_krsc: bool,
is_kc_not_ck: bool, nhot: int, index: int):
kv_dim = 1 if all_weight_is_krsc else 0
inp_buffer = self.alloc.allocated[inp_buffer_n]
filters = self.alloc.allocated[filters_n]
if not all_weight_is_krsc:
filters = filters.reshape(-1, *filters.shape[-2:])
else:
filters = filters.reshape(filters.shape[0], -1, filters.shape[-1])
out_buffer = self.alloc.allocated[out_buffer_n]
filters_i = filters.select(kv_dim, index)
filters_cur = filters_i if not is_kc_not_ck else filters_i.T
if inp_buffer.dtype == torch.float16:
inp_buffer_np = torch_tensor_to_tv(inp_buffer).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
filters_i_np = filters_np[
index] if not all_weight_is_krsc else filters_np[:, index]
filters_cur_np = filters_i_np if not is_kc_not_ck else filters_i_np.T
out_buffer_np = torch_tensor_to_tv(out_buffer).numpy_view()
np.matmul(inp_buffer_np[:nhot],
filters_cur_np,
out=out_buffer_np[:nhot])
else:
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
def indice_conv_bwd_init_gemm(self, features_n: str, filters_n: str,
out_bp_n: str, dfilters_n: str,
all_weight_is_krsc: bool, is_kc_not_ck: bool,
kv_center: int, stream_int: int = 0):
features = self.alloc.allocated[features_n]
filters = self.alloc.allocated[filters_n]
out_bp = self.alloc.allocated[out_bp_n]
dfilters = self.alloc.allocated[dfilters_n]
if not all_weight_is_krsc:
filters = filters.reshape(-1, *filters.shape[-2:])
dfilters = dfilters.reshape(-1, *filters.shape[-2:])
else:
filters = filters.reshape(filters.shape[0], -1, filters.shape[-1])
dfilters = dfilters.reshape(filters.shape[0], -1, filters.shape[-1])
if not all_weight_is_krsc:
if not is_kc_not_ck:
torch.mm(features.T, out_bp, out=dfilters[kv_center])
din = torch.mm(out_bp, filters[kv_center].T)
else:
torch.mm(out_bp.T, features, out=dfilters[kv_center])
din = torch.mm(out_bp, filters[kv_center])
else:
# KN @ NC
torch.mm(out_bp.T, features, out=dfilters[:, kv_center])
# NK @ KC
din = torch.mm(out_bp, filters[:, kv_center])
self.alloc.allocated[AllocKeys.DIn] = din
return torch_tensor_to_tv(din)
def indice_conv_bwd_cpu_gemm(self, inp_buffer_n: str,
out_buffer_n: str, filters_n: str, dfilters_n: str,all_weight_is_krsc: bool,
is_kc_not_ck: bool, nhot: int, index: int):
kv_dim = 1 if all_weight_is_krsc else 0
inp_buffer = self.alloc.allocated[inp_buffer_n]
out_buffer = self.alloc.allocated[out_buffer_n]
filters = self.alloc.allocated[filters_n]
dfilters = self.alloc.allocated[dfilters_n]
if not all_weight_is_krsc:
filters = filters.reshape(-1, *filters.shape[-2:])
dfilters = dfilters.reshape(-1, *filters.shape[-2:])
else:
filters = filters.reshape(filters.shape[0], -1, filters.shape[-1])
dfilters = dfilters.reshape(filters.shape[0], -1, filters.shape[-1])
filters_i = filters.select(kv_dim, index)
dfilters_i = dfilters.select(kv_dim, index)
filters_KC = filters_i if is_kc_not_ck else filters_i.T
if is_kc_not_ck:
# KN @ NC
torch.mm(out_buffer[:nhot].T, inp_buffer[:nhot], out=dfilters_i)
else:
# CN @ NK
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_i)
# NK @ KC
torch.mm(out_buffer[:nhot], filters_KC, out=inp_buffer[:nhot])
if __name__ == "__main__":
a = torch.rand(2, 2)
atv = torch_tensor_to_tv(a)
......
......@@ -23,23 +23,25 @@ import spconv
from spconv.core import AlgoHint, ConvAlgo
from typing import Dict, List, Optional, Union
from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream, get_arch
from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM
import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.utils import nullcontext
if hasattr(_ext, "cumm"):
CPU_ONLY_BUILD = False
from spconv.algo import GEMM, CONV # , GATHER, SCATTER
from spconv.algo import GEMM, CONV, GEMM_CPP, CONV_CPP
else:
CPU_ONLY_BUILD = True
GEMM = None
CONV = None
GEMM_CPP = None
CONV_CPP = None
import time
from spconv.constants import FILTER_HWIO, ALL_WEIGHT_IS_KRSC
from spconv.constants import FILTER_HWIO, ALL_WEIGHT_IS_KRSC, AllocKeys
from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer
......@@ -103,14 +105,30 @@ def get_indice_pairs(indices: torch.Tensor,
dilation: List[int],
out_padding: List[int],
subm: bool = False,
transpose: bool = False):
transpose: bool = False,
num_out_act_bound: int = -1):
# torch.cuda.synchronize()
# t = time.time()
# stream = get_current_stream()
# CONV.stream_synchronize(stream)
# t = time.time()
if SPCONV_CPP_INDICE_PAIRS:
alloc = TorchAllocator(indices.device)
stream = 0
if indices.is_cuda:
stream = get_current_stream()
num_act_out = SpconvOps.get_indice_pairs(alloc, torch_tensor_to_tv(indices), batch_size, spatial_shape,
algo.value, ksize, stride, padding, dilation, out_padding, subm, transpose, stream)
if subm:
out_inds = indices
else:
out_inds = alloc.allocated[AllocKeys.OutIndices]
pair = alloc.allocated[AllocKeys.Pair]
indice_num_per_loc = alloc.allocated[AllocKeys.IndiceNumPerLoc]
# print(subm, out_inds.shape, pair.shape, indice_num_per_loc.shape, num_act_out)
return out_inds[:num_act_out], pair, indice_num_per_loc
ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
if not subm:
......@@ -152,7 +170,6 @@ def get_indice_pairs(indices: torch.Tensor,
# device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
SpconvOps.generate_subm_conv_inds(inds_tv,
hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
......@@ -200,6 +217,10 @@ def get_indice_pairs(indices: torch.Tensor,
stream_int=stream)
uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1
use_bound_algo = False
if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound
use_bound_algo = True
uniq_res_tv = torch_tensor_to_tv(uniq_res)
# num_act_out = SpconvOps.generate_conv_inds_stage1_5(
# indice_pairs_uniq_tv,
......@@ -224,6 +245,7 @@ def get_indice_pairs(indices: torch.Tensor,
uniq_res_tv,
indice_pairs_uniq_tv,
out_inds_tv,
indice_num_per_loc_tv,
num_out_act=num_act_out,
batch_size=batch_size,
output_dims=out_shape,
......@@ -233,7 +255,8 @@ def get_indice_pairs(indices: torch.Tensor,
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream)
stream_int=stream,
use_bound_algo=use_bound_algo)
else:
out_inds = torch.empty((kv * indices.shape[0], indices.shape[1]),
dtype=indices.dtype,
......@@ -258,7 +281,6 @@ def get_indice_pairs(indices: torch.Tensor,
# print("REGU", time.time() - t)
return out_inds, pair, indice_num_per_loc
def get_indice_pairs_implicit_gemm(
indices: torch.Tensor,
batch_size: int,
......@@ -273,7 +295,8 @@ def get_indice_pairs_implicit_gemm(
transpose: bool = False,
is_train: bool = True,
alloc: Optional[ThrustSortAllocator] = None,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
timer: CUDAKernelTimer = CUDAKernelTimer(False),
num_out_act_bound: int = -1):
"""
Why return tuple? because pytorch seems don't support custom object in autograd.
return: (
......@@ -289,6 +312,62 @@ def get_indice_pairs_implicit_gemm(
)
"""
stream = get_current_stream()
if SPCONV_CPP_INDICE_PAIRS_IGEMM:
thalloc = TorchAllocator(indices.device)
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc, torch_tensor_to_tv(indices), batch_size, spatial_shape,
algo.value, ksize, stride, padding, dilation, out_padding, subm, transpose, is_train, stream,
num_out_act_bound)
mask_split_count = mask_tensor.dim(0)
masks = [mask_tensor[i:i+1].numpy() for i in range(mask_split_count)]
if subm:
out_inds = indices
else:
out_inds = thalloc.allocated[AllocKeys.OutIndices]
pair = thalloc.allocated[AllocKeys.Pair]
indice_num_per_loc = thalloc.allocated[AllocKeys.IndiceNumPerLoc]
if subm:
pair_mask = thalloc.allocated[AllocKeys.PairMask]
mask_argsort = thalloc.allocated[AllocKeys.MaskArgSort]
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
mask_argsort_in_splits = [
mask_argsort[i] for i in range(mask_split_count)
]
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
pair_bwd = pair
pair_fwd = thalloc.allocated[AllocKeys.PairFwd]
pair_mask_fwd = thalloc.allocated[AllocKeys.PairMask]
pair_mask_bwd = torch.Tensor()
mask_argsort_bwd = torch.Tensor()
if is_train:
pair_mask_bwd = thalloc.allocated[AllocKeys.PairMaskBwd]
mask_argsort_bwd = thalloc.allocated[AllocKeys.MaskArgSortBwd]
mask_argsort_fwd = thalloc.allocated[AllocKeys.MaskArgSort]
if not is_train:
pair_bwd = torch.Tensor()
pair_mask_bwd_splits: List[torch.Tensor] = []
mask_argsort_bwd_splits: List[torch.Tensor] = []
else:
pair_mask_bwd_splits = [
pair_mask_bwd[i] for i in range(mask_split_count)
]
mask_argsort_bwd_splits = [
mask_argsort_bwd[i] for i in range(mask_split_count)
]
pair_mask_fwd_splits = [
pair_mask_fwd[i] for i in range(mask_split_count)
]
mask_argsort_fwd_splits = [
mask_argsort_fwd[i] for i in range(mask_split_count)
]
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
t = 0
if DEBUG:
CONV.stream_synchronize(stream)
......@@ -443,6 +522,8 @@ def get_indice_pairs_implicit_gemm(
uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1
if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound
if DEBUG:
CONV.stream_synchronize(stream)
......@@ -627,10 +708,36 @@ def indice_conv(features: torch.Tensor,
# t = time.time()
if not features.is_contiguous():
features = features.contiguous()
if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress")
if SPCONV_CPP_GEMM and GEMM_CPP is not None:
# print("CPPPPPP!!!", features.device)
alloc = TorchAllocator(features.device)
from spconv.core_cc.csrc.sparse.convops import SimpleExternalSpconvMatmul
# ext_mm = TorchSpconvMatmul(alloc)
if features.is_cuda:
ext_mm = SimpleExternalSpconvMatmul(alloc)
else:
ext_mm = TorchSpconvMatmul(alloc)
alloc.allocated[AllocKeys.Features] = features
alloc.allocated[AllocKeys.Filters] = filters
features_tv = torch_tensor_to_tv(features)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num)
filters_tv = torch_tensor_to_tv(filters)
stream = 0
if features.is_cuda:
stream = get_current_stream()
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, indice_pairs_tv, indice_pair_num_tv, num_activate_out,
inverse, subm, algo.value, stream)
out_features = alloc.allocated[AllocKeys.OutFeatures]
return out_features
if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
......@@ -642,7 +749,6 @@ def indice_conv(features: torch.Tensor,
filter_shape_per_kv = [out_channel, filters.shape[-1]]
filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0]
else:
kv_dim = 1
out_channel = filters.shape[0]
......@@ -651,7 +757,6 @@ def indice_conv(features: torch.Tensor,
kv = filters.shape[1]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2
if subm:
# out_features = torch.zeros((num_activate_out, out_channel),
......@@ -663,7 +768,19 @@ def indice_conv(features: torch.Tensor,
else:
out_features = torch.mm(features, filters[kv_center].T)
else:
out_features = torch.mm(features, filters[:, kv_center].T)
if features.is_cuda or (features.dtype != torch.float16):
out_features = torch.mm(features, filters[:, kv_center].T)
else:
# pytorch 1.12 don't support cpu half mm, f**k pytorch
# we need cpu fp16 mm for test only.
out_features = torch.empty((features.shape[0], out_channel),
dtype=features.dtype,
device=features.device)
features_np = torch_tensor_to_tv(features).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
out_features_np = torch_tensor_to_tv(out_features).numpy_view()
np.matmul(features_np, filters_np[:, kv_center].T, out=out_features_np)
# out_features = torch.mm(features, filters[:, kv_center].T)
else:
out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype,
......@@ -706,7 +823,15 @@ def indice_conv(features: torch.Tensor,
SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices)
filters_i = filters.select(kv_dim, i)
filters_cur = filters_i if not is_KC_not_CK else filters_i.T
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
if features.dtype == torch.float16:
inp_buffer_np = torch_tensor_to_tv(inp_buffer).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
filters_i_np = filters_np[i] if not ALL_WEIGHT_IS_KRSC else filters_np[:, i]
filters_cur_np = filters_i_np if not is_KC_not_CK else filters_i_np.T
out_buffer_np = torch_tensor_to_tv(out_buffer).numpy_view()
np.matmul(inp_buffer_np[:nhot], filters_cur_np, out=out_buffer_np[:nhot])
else:
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices)
return out_features
......@@ -750,7 +875,7 @@ def indice_conv(features: torch.Tensor,
profile_idx, :nhot_profile]
inp_indices = torch_tensor_to_tv(inp_indices_th)
out_indices = torch_tensor_to_tv(out_indices_th)
filter_tv = torch_tensor_to_tv(filters)[profile_idx]
# filter_tv = torch_tensor_to_tv(filters)[profile_idx]
filter_tv = torch_tensor_to_tv(filters).select(kv_dim, profile_idx)
tuned_res, min_time = GEMM.tune_and_cache(
......@@ -826,8 +951,40 @@ def indice_conv_backward(features: torch.Tensor,
algo: ConvAlgo = ConvAlgo.Native,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# print(out_bp.mean(), out_bp.max(), out_bp.min())
if SPCONV_CPP_GEMM and GEMM_CPP is not None:
alloc = TorchAllocator(features.device)
ext_mm = TorchSpconvMatmul(alloc)
alloc.allocated[AllocKeys.Features] = features
alloc.allocated[AllocKeys.Filters] = filters
alloc.allocated[AllocKeys.OutBp] = out_bp
features_tv = torch_tensor_to_tv(features)
out_bp_tv = torch_tensor_to_tv(out_bp)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num)
filters_tv = torch_tensor_to_tv(filters)
stream = 0
if features.is_cuda:
stream = get_current_stream()
ConvGemmOps.indice_conv_backward(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv, indice_pairs_tv, indice_pair_num_tv,
inverse, subm, algo.value, stream)
din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters]
return din, df
filters_shape = filters.shape
# TODO handle this in nn.Module to make sure features in backward is contiguous
if not features.is_contiguous():
features = features.contiguous()
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
assert out_bp.is_contiguous()
assert filters.is_contiguous()
assert features.is_contiguous()
if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
......@@ -849,14 +1006,6 @@ def indice_conv_backward(features: torch.Tensor,
filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2
# TODO handle this in nn.Module to make sure features in backward is contiguous
if not features.is_contiguous():
features = features.contiguous()
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
assert out_bp.is_contiguous()
assert filters.is_contiguous()
assert features.is_contiguous()
if subm:
dfilters = torch.zeros_like(filters)
......@@ -1141,6 +1290,31 @@ def implicit_gemm(features: torch.Tensor,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None):
stream = get_current_stream()
if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device)
features_tv = torch_tensor_to_tv(features)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_splits_tv = [torch_tensor_to_tv(t, tv.uint32) for t in pair_mask_fwd_splits]
mask_argsort_fwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_fwd_splits]
filters_tv = torch_tensor_to_tv(filters)
mask = np.concatenate(masks)
mask_tv = tv.from_numpy(mask).clone()
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
auto_fp32_accum = fp32_accum is None
if fp32_accum is None:
fp32_accum = False
mask_width = ConvGemmOps.implicit_gemm(alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, pair_mask_fwd_splits_tv,
mask_argsort_fwd_splits_tv, num_activate_out, mask_tv, is_train, is_subm, stream, timer_cpp, auto_fp32_accum,
fp32_accum)
out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train:
assert mask_output_fwd is not None
return out_features, mask_output_fwd, mask_width
# if DEBUG:
# CONV.stream_synchronize(stream)
......@@ -1225,7 +1399,8 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream)
# t = time.time()
print(tune_res.algo_desp)
# print(tune_res.algo_desp)
# with tv.measure_and_print("f16 time"):
with timer.record("implicit_gemm", stream):
for j in range(num_split):
beta = 0 if j == 0 else 1
......@@ -1245,6 +1420,81 @@ def implicit_gemm(features: torch.Tensor,
beta=beta,
stream=stream,
verbose=False)
# INT8_TEST = True
# if INT8_TEST:
# if features.shape[1] % 32 != 0:
# return out_features, mask_output_fwd, mask_width
# features = features.to(torch.int8)
# filters = filters.to(torch.int8)
# out_features_i8 = out_features.to(torch.int8)
# features_tv = torch_tensor_to_tv(features)
# filters_tv = torch_tensor_to_tv(filters)
# out_features_i8_tv = torch_tensor_to_tv(out_features_i8)
# tune_res = CONV.get_tuned_algo(ConvOpType.kForward, features_tv.dtype,
# filters_tv.dtype, out_features_i8_tv.dtype,
# out_channel, in_channel, arch)
# if tune_res is None:
# tune_res, _ = CONV.tune_and_cache(
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# NHWC,
# KRSC,
# NHWC,
# arch,
# mask=pair_mask_fwd_split_tvs[0],
# mask_argsort=mask_argsort_fwd_split_tvs[0],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks[0].item(),
# stream=stream,
# fp32_accum=fp32_accum)
# mask_width = tune_res.algo_desp.tile_shape[0]
# if is_train:
# mask_output_fwd = torch.empty(
# [num_split,
# codeops.div_up(num_activate_out, mask_width)],
# dtype=torch.int32,
# device=features.device)
# # pytorch don't support uint32.
# mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd,
# dtype=tv.uint32)
# mask_output_fwd_tvs = [mask_output_fwd_tv[j] for j in range(num_split)]
# else:
# mask_output_fwd = None
# mask_output_fwd_tv = tv.Tensor()
# mask_output_fwd_tvs = [tv.Tensor() for _ in range(num_split)]
# # CONV.stream_synchronize(stream)
# # print("FPREPARE", time.time() - t)
# # # t = time.time()
# # CONV.stream_synchronize(stream)
# # t = time.time()
# # print(tune_res.algo_desp)
# with tv.measure_and_print(f"i8 time {features.shape[0]}-{in_channel}-{out_channel}"):
# with timer.record("implicit_gemm_i8", stream):
# for j in range(num_split):
# beta = 0 if j == 0 else 1
# CONV.run_with_tuned_result(
# tune_res,
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# mask=pair_mask_fwd_split_tvs[j],
# mask_argsort=mask_argsort_fwd_split_tvs[j],
# mask_output=mask_output_fwd_tvs[j],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks_ints[j],
# mask_width=-1,
# beta=beta,
# stream=stream,
# verbose=False)
# torch.cuda.synchronize()
# if DEBUG:
......@@ -1266,7 +1516,7 @@ def implicit_gemm_backward(features: torch.Tensor,
pair_mask_bwd_splits: List[torch.Tensor],
mask_argsort_fwd_splits: List[torch.Tensor],
mask_argsort_bwd_splits: List[torch.Tensor],
mask_output_fwd: torch.Tensor,
mask_output_fwd: Optional[torch.Tensor],
masks: List[np.ndarray],
mask_width: int,
is_subm: bool,
......@@ -1275,14 +1525,53 @@ def implicit_gemm_backward(features: torch.Tensor,
# print(out_bp.mean(), out_bp.max(), out_bp.min())
if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress")
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
if not features.is_contiguous():
features = features.contiguous()
if mask_output_fwd is None:
raise ValueError("you must do bwd with net.train()")
assert out_bp.is_contiguous()
assert filters.is_contiguous()
assert features.is_contiguous()
stream = get_current_stream()
if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device)
features_tv = torch_tensor_to_tv(features)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_bwd_tv = torch_tensor_to_tv(pair_bwd)
pair_mask_fwd_splits_tv = [torch_tensor_to_tv(t) for t in pair_mask_fwd_splits]
pair_mask_bwd_splits_tv = [torch_tensor_to_tv(t) for t in pair_mask_bwd_splits]
mask_argsort_fwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_fwd_splits]
mask_argsort_bwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_bwd_splits]
filters_tv = torch_tensor_to_tv(filters)
out_bp_tv = torch_tensor_to_tv(out_bp)
mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd)
mask = np.concatenate(masks)
mask_tv = tv.from_numpy(mask).clone()
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
auto_fp32_accum = fp32_accum is None
if fp32_accum is None:
fp32_accum = False
ConvGemmOps.implicit_gemm_backward(alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv, mask_argsort_fwd_splits_tv,
mask_argsort_bwd_splits_tv, mask_output_fwd_tv, mask_tv, mask_width, is_subm, stream, timer_cpp, auto_fp32_accum,
fp32_accum)
din = alloc.allocated[AllocKeys.DIn]
dfilters = alloc.allocated[AllocKeys.DFilters]
return din, dfilters
# here filters is KRSC
filters_shape = filters.shape
out_channel = filters.shape[0]
......@@ -1297,7 +1586,6 @@ def implicit_gemm_backward(features: torch.Tensor,
filters = filters.reshape(out_channel, -1, filters.shape[-1])
kv = filters.shape[1]
stream = get_current_stream()
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_bwd_tv = torch_tensor_to_tv(pair_bwd)
......@@ -1366,12 +1654,12 @@ def implicit_gemm_backward(features: torch.Tensor,
KRSC,
NHWC,
arch,
mask=pair_mask_fwd_split_tvs[0],
mask=mask_output_fwd_tv[0],
mask_argsort=mask_argsort_fwd_split_tvs[0],
indices=pair_fwd_tv,
reverse_mask=False,
mask_filter=masks[0].item(),
mask_output=mask_output_fwd_tv[0],
mask_output=tv.Tensor(),
mask_width=mask_width,
stream=stream)
workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
......
"""this file can only be used by spconv developer for now.
the "tensorpc" isn't a open source project.
"""
import tensorpc
from tensorpc.apps.flow.flowapp import App
class TestApp(App):
pass
\ No newline at end of file
......@@ -25,7 +25,7 @@ import spconv.pytorch as spconv
from spconv.utils import Point2VoxelCPU3d
# torch.backends.cudnn.enabled = False
def waymo_data(batch_size=1):
def waymo_data(batch_size=1, num_features=-1):
gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3,
150000, 1)
# gen = VoxelGeneratorV2([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 1,
......@@ -35,11 +35,39 @@ def waymo_data(batch_size=1):
print(pc.shape)
voxels_tv, indices_tv, _ = gen.point_to_voxel(tv.from_numpy(pc))
voxels = voxels_tv.numpy().reshape(-1, 3)
if num_features > 0:
voxels = np.zeros((voxels.shape[0], num_features), dtype=voxels.dtype)
coors = indices_tv.numpy()
N = coors.shape[0]
coors = np.concatenate([np.full([N, 1], 0, coors.dtype), coors], axis=1)
return voxels, coors, gen.grid_size
def waymo_data_large(batch_size=1):
gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3,
1200000, 1)
# gen = VoxelGeneratorV2([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 1,
# 150000)
data = np.load(Path(__file__).parent / "data" / "benchmark-pc.npz")
pc = np.ascontiguousarray(data["pc"])
pc2 = pc.copy()
pc2[:, 1] += 1
pc3 = pc.copy()
pc3[:, 1] += 2
pc4 = pc.copy()
pc4[:, 1] += 3
pc5 = pc.copy()
pc5[:, 1] += 4
pc = np.concatenate([pc, pc2, pc3, pc4, pc5])
print(pc.shape)
voxels_tv, indices_tv, _ = gen.point_to_voxel(tv.from_numpy(pc))
voxels = voxels_tv.numpy().reshape(-1, 3)
coors = indices_tv.numpy()
N = coors.shape[0]
print("num voxels", N)
coors = np.concatenate([np.full([N, 1], 0, coors.dtype), coors], axis=1)
return voxels, coors, gen.grid_size
class Net(nn.Module):
def __init__(self, shape, algo):
......@@ -61,6 +89,21 @@ class Net(nn.Module):
# # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo),
# spconv.SubMConv3d(64, 64, 3, bias=False, indice_key="c0",
# algo=algo),
# spconv.SubMConv3d(32,
# 32,
# 3,
# bias=False,
# indice_key="c0",
# algo=algo),
# # nn.BatchNorm1d(32),
# # nn.ReLU(),
# # spconv.SparseConv3d(64, 64, 2, 2, bias=False,
# # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo),
spconv.SubMConv3d(64,
64,
3,
......@@ -275,7 +318,7 @@ def main():
import pickle
np.random.seed(50051)
torch.manual_seed(50051)
# voxels, coors, spatial_shape = waymo_data()
# voxels, coors, spatial_shape = waymo_data(num_features=128)
# with open("/home/yy/test_spconv.pkl", "wb") as f:
# pickle.dump((voxels, coors, spatial_shape), f)
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
......@@ -312,7 +355,7 @@ def main():
# MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms
# algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train()
net = Net(spatial_shape, algo).to(device).eval().to(dtype).train()
# net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape)
......@@ -345,18 +388,18 @@ def main():
print("spconv time", np.mean(times[10:]))
times = []
# for i in range(10):
# out = net(voxels_th, coors_th, 1)
# print("------------")
# torch.cuda.synchronize()
# t = time.time()
# out.features.backward(dout_t)
# torch.cuda.synchronize()
# times.append(time.time() - t)
# # # print((net.grid == -1).float().sum(), net.grid.numel())
# # # print("spconv time", time.time() - t)
# print("spconv bw time", np.mean(times[5:]))
for i in range(10):
out = net(voxels_th, coors_th, 1)
print("------------")
torch.cuda.synchronize()
t = time.time()
out.features.backward(dout_t)
torch.cuda.synchronize()
times.append(time.time() - t)
# # print((net.grid == -1).float().sum(), net.grid.numel())
# # print("spconv time", time.time() - t)
print("spconv bw time", np.mean(times[5:]))
if __name__ == "__main__":
......
......@@ -30,6 +30,7 @@ import numpy as np
import pccm
import torch
import torch.nn.functional as F
from spconv.core_cc.csrc.sparse.convops import GemmTuneResult, ConvTuneResult
from spconv.test_utils import TestCase
from cumm import tensorview as tv
from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType
......@@ -38,11 +39,11 @@ from cumm.gemm.codeops import div_up
from spconv.core import AlgoHint, ConvAlgo
from spconv.pytorch.conv import expand_nd
from spconv.pytorch import ops
from spconv.algo import CONV, GEMM, BestAlgoByProfile, BestConvAlgoByProfile
from spconv.algo import GEMM, CONV, GEMM_CPP, CONV_CPP, BestAlgoByProfile, BestConvAlgoByProfile, GemmTunerSimple
from spconv.pytorch.cppcore import get_current_stream, torch_tensor_to_tv
from spconv.test_utils import generate_sparse_data, params_grid
import tqdm
from spconv.constants import ALL_WEIGHT_IS_KRSC
from spconv.constants import ALL_WEIGHT_IS_KRSC, SPCONV_CPP_GEMM
assert ALL_WEIGHT_IS_KRSC is True, "we only support KRSC in spconv >= 2.2"
......@@ -67,13 +68,13 @@ class SparseConvTester:
self.dtype_th = NUMPY_DTYPE_TO_TORCH[dtype]
self.K = K
self.C = C
self.ksize = expand_nd(ksize, ndim)
self.stride = expand_nd(stride, ndim)
self.padding = expand_nd(padding, ndim)
self.dilation = expand_nd(dilation, ndim)
self.ksize = expand_nd(ndim, ksize)
self.stride = expand_nd(ndim, stride)
self.padding = expand_nd(ndim, padding, )
self.dilation = expand_nd(ndim, dilation)
self.N = N
self.device = torch.device("cuda:0")
op = expand_nd(0, ndim)
op = expand_nd(ndim, 0)
self.kv: int = np.prod(self.ksize)
self.num_split = 1 if algo == ConvAlgo.MaskImplicitGemm else 2
......@@ -139,7 +140,9 @@ class SparseConvTester:
self.weight_ref = np.ascontiguousarray(self.weight_ref).reshape(-1, K, C)
self.out_ref, self.din_ref, self.dw_ref = self._get_ref_output()
self.dw_ref = np.ascontiguousarray(self.dw_ref.transpose(1, 0, 2).reshape(K, *self.ksize, C))
self.arch = tv.get_compute_capability()
def _get_ref_output(self):
output_ref = np.zeros_like(self.output, dtype=np.float32)
......@@ -174,6 +177,8 @@ class SparseConvTester:
dw_res = out_gather.astype(
np.float32).T @ inp_gather.astype(np.float32)
dw_ref[filter_offset] = dw_res
if self.dtype == np.int8:
output_ref = np.clip(output_ref, -127, 127)
return output_ref, dinput_ref, dw_ref
def get_operands(self, op_type: ConvOpType):
......@@ -220,9 +225,15 @@ def _test_impgemm_conv_cuda(subm: bool):
shapes = [[19, 18, 17]]
batchsizes = [1]
dtypes = [np.float32, np.float16]
dtypes = [np.int8]
test_case = TestCase()
in_channels = [512]
out_channels = [512]
# in_channels = [32]
# out_channels = [32, 48, 64]
in_channels = [32, 47]
out_channels = [32, 48, 62]
in_channels = [32]
out_channels = [32]
multiple_base = 16
if subm:
ksizes = [3]
......@@ -239,7 +250,7 @@ def _test_impgemm_conv_cuda(subm: bool):
ConvAlgo.MaskImplicitGemm,
]
arch = torch.cuda.get_device_capability()
force_nvrtc = False
for shape, bs, C, K, k, s, p, d, algo, dtype in tqdm.tqdm(params_grid(
shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations, algos, dtypes)):
......@@ -259,15 +270,19 @@ def _test_impgemm_conv_cuda(subm: bool):
spk = 1
for op_type in op_types:
inp_tv, weight_tv, output_tv = tester.get_operands(op_type)
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1)
print(avail_desps)
if SPCONV_CPP_GEMM:
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch, op_type.value, -1, True, False)
else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1)
for desp in avail_desps:
if not subm:
if op_type == ConvOpType.kForward:
output_tv.zero_()
else:
inp_tv.zero_()
# this algo must success
mask_width = desp.tile_shape[0]
# if mask_width != 32:
......@@ -279,9 +294,9 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_output_fwd = mask_width_to_mask_out_fwd[mask_width]
if subm:
if desp.op_type == ConvOpType.kForward.value:
if desp.op_type.value == ConvOpType.kForward.value:
indice_pairs = tester.pair_fwd
elif desp.op_type == ConvOpType.kBackwardInput.value:
elif desp.op_type.value == ConvOpType.kBackwardInput.value:
indice_pairs = tester.pair_bwd
else:
indice_pairs = tester.pair_fwd
......@@ -292,31 +307,57 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_filter = tester.masks[j].item()
reverse_mask = False
if desp.op_type == ConvOpType.kBackwardWeight.value:
if desp.op_type.value == ConvOpType.kBackwardWeight.value:
mask_op = mask_output[j]
else:
mask_op = tester.pair_mask_fwd_splits[j]
if desp.op_type == ConvOpType.kBackwardInput.value:
if desp.op_type.value == ConvOpType.kBackwardInput.value:
reverse_mask = True
mask_output_run = torch_tensor_to_tv(mask_output[j], dtype=tv.uint32)
if desp.op_type == ConvOpType.kBackwardWeight.value:
if desp.op_type.value == ConvOpType.kBackwardWeight.value:
mask_output_run = tv.Tensor()
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, spk),
desp.op_type,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]),
mask_output_run,
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
# force_nvrtc = desp.op_type.value == ConvOpType.kBackwardInput.value
# if force_nvrtc:
# desp.is_nvrtc = True
# print(force_nvrtc, desp.op_type, op_type)
if SPCONV_CPP_GEMM:
CONV_CPP.run_with_tuned_result(
ConvTuneResult(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]),
mask_output_run,
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
force_nvrtc=force_nvrtc,
)
else:
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]),
mask_output_run,
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
force_nvrtc=force_nvrtc,
)
else:
if mask_width not in mask_width_to_mask_out_bwd:
mask_width_to_mask_out_bwd[mask_width] = torch.zeros([2, div_up(tester.indices_np.shape[0], mask_width)],
......@@ -324,12 +365,12 @@ def _test_impgemm_conv_cuda(subm: bool):
device=tester.device)
mask_output_bwd = mask_width_to_mask_out_bwd[mask_width]
if desp.op_type == ConvOpType.kForward.value:
if desp.op_type.value == ConvOpType.kForward.value:
indice_pairs = tester.pair_fwd # inp -> out
mask_ops = tester.pair_mask_fwd_splits
mask_argsorts = tester.mask_argsort_fwd_splits
mask_output = mask_output_fwd
elif desp.op_type == ConvOpType.kBackwardInput.value:
elif desp.op_type.value == ConvOpType.kBackwardInput.value:
indice_pairs = tester.pair_bwd # out -> inp
mask_ops = tester.pair_mask_bwd_splits
mask_argsorts = tester.mask_argsort_bwd_splits
......@@ -344,27 +385,46 @@ def _test_impgemm_conv_cuda(subm: bool):
beta = 1 if j == 1 else 0
mask_filter = tester.masks[j].item()
reverse_mask = False
if desp.op_type == ConvOpType.kBackwardWeight.value:
if desp.op_type.value == ConvOpType.kBackwardWeight.value:
mask_op = mask_output[j]
else:
mask_op = mask_ops[j]
if SPCONV_CPP_GEMM:
CONV_CPP.run_with_tuned_result(
ConvTuneResult(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(mask_argsorts[j]),
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
else:
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(mask_argsorts[j]),
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, spk),
desp.op_type,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(mask_argsorts[j]),
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
out_ref = tester.out_ref
din_ref = tester.din_ref
dw_ref = tester.dw_ref
......@@ -374,8 +434,8 @@ def _test_impgemm_conv_cuda(subm: bool):
test_case.assertAllClose(out_ref, out_my, atol=atol, rtol=rtol)
else:
error_norm = np.linalg.norm(out_ref.reshape(-1) - out_my.reshape(-1))
# if (error_norm > 5):
print(f"{desp}, Error={error_norm}")
if (error_norm > 5):
print(f"{desp}, Error={error_norm}")
assert error_norm < 10 * multipler
# print(desp, )
else:
......@@ -389,7 +449,13 @@ def _test_impgemm_conv_cuda(subm: bool):
for spk in [1, 4, 16, 64]:
for mask_width, mask_output in mask_width_to_mask_out_fwd.items():
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width)
if SPCONV_CPP_GEMM:
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch,
ConvOpType.kBackwardWeight.value, mask_width, True, False)
else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width)
for desp in avail_desps:
weight_tv.zero_()
if subm:
......@@ -403,8 +469,8 @@ def _test_impgemm_conv_cuda(subm: bool):
# bit_ref = np.bitwise_or.reduce(mask_op_np, axis=0)
# bit_my = mask_filter
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, spk),
desp.op_type,
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
......@@ -430,8 +496,8 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_op = mask_output[j]
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, spk),
desp.op_type,
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
......@@ -499,6 +565,8 @@ def _test_native_conv_cuda(subm: bool):
pair_out = torch_tensor_to_tv(tester.pair_native)[1]
op_types = [ConvOpType.kForward, ConvOpType.kBackwardInput, ConvOpType.kBackwardWeight]
# op_types = [ConvOpType.kForward]
indice_pair_num_cpu = tester.indice_num_per_loc_np
spk = 1
......@@ -517,9 +585,11 @@ def _test_native_conv_cuda(subm: bool):
a = inp_tv
c = output_tv
b = weight_tv.select(1, tester.kv // 2)
if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, False, True, False, arch, ShuffleStrideType.ShuffleAC.value)
else:
avail_desps = GEMM.get_all_available(a, b, c, False, True, False, arch, ShuffleStrideType.ShuffleAC)
avail_desps = GEMM.get_all_available(a, b, c, False, True, False, arch, ShuffleStrideType.ShuffleAC)
for desp in avail_desps:
if subm:
torch.mm(inp_th, weight_th[:, tester.kv // 2].T, out=output_th)
......@@ -538,22 +608,42 @@ def _test_native_conv_cuda(subm: bool):
b = weight_tv.select(1, i)
# inp @ filter.T, NC @ KC
beta = 1.0 if inited else 0.0
GEMM.run_with_tuned_result(
BestAlgoByProfile(desp, 1),
a,
b,
c,
False,
True,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds=inp_indices,
c_inds=out_indices,
hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
if SPCONV_CPP_GEMM:
GEMM_CPP.run_with_tuned_result(
GemmTuneResult(desp, tester.arch, 1),
a,
b,
c,
False,
True,
False,
arch=arch,
stream_int=stream,
shuffle_type=ShuffleStrideType.ShuffleAC.value,
a_inds=inp_indices,
b_inds=tv.Tensor(),
c_inds=out_indices,
hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
else:
GEMM.run_with_tuned_result(
BestAlgoByProfile(desp, tester.arch, 1),
a,
b,
c,
False,
True,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds=inp_indices,
c_inds=out_indices,
hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
inited = True
out_my = output_tv.cpu().numpy()
if dtype != np.float16:
......@@ -570,7 +660,11 @@ def _test_native_conv_cuda(subm: bool):
a = output_tv
b = weight_tv.select(1, tester.kv // 2)
c = inp_tv
avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC)
if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC.value)
else:
avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC)
for desp in avail_desps:
if subm:
torch.mm(output_th, weight_th[:, tester.kv // 2], out=inp_th)
......@@ -589,22 +683,42 @@ def _test_native_conv_cuda(subm: bool):
b = weight_tv.select(1, i)
# inp @ filter.T, NC @ KC
beta = 1.0 if inited else 0.0
GEMM.run_with_tuned_result(
BestAlgoByProfile(desp, 1),
a,
b,
c,
False,
False,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds=out_indices,
c_inds=inp_indices,
hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
if SPCONV_CPP_GEMM:
GEMM_CPP.run_with_tuned_result(
GemmTuneResult(desp, tester.arch, 1),
a,
b,
c,
False,
False,
False,
arch=arch,
stream_int=stream,
shuffle_type=ShuffleStrideType.ShuffleAC.value,
a_inds=out_indices,
b_inds=tv.Tensor(),
c_inds=inp_indices,
hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
else:
GEMM.run_with_tuned_result(
BestAlgoByProfile(desp, tester.arch, 1),
a,
b,
c,
False,
False,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds=out_indices,
c_inds=inp_indices,
hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
inited = True
din_my = inp_tv.cpu().numpy()
if dtype != np.float16:
......@@ -616,13 +730,18 @@ def _test_native_conv_cuda(subm: bool):
else:
error_norm = np.linalg.norm(din_ref.reshape(-1) - din_my.reshape(-1))
assert error_norm < 10 * multipler
else:
a = output_tv
b = inp_tv
c = weight_tv.select(1, tester.kv // 2)
avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB)
if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB.value)
else:
avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB)
for desp in avail_desps:
# print(desp, C, K, k, s, p, d)
# desp.is_nvrtc = True
inited = subm
weight_tv.zero_()
if subm:
......@@ -640,42 +759,57 @@ def _test_native_conv_cuda(subm: bool):
out_indices = pair_out[i].slice_first_axis(0, nhot)
a_inds = out_indices
b_inds = inp_indices
if SPCONV_CPP_GEMM:
GEMM_CPP.run_with_tuned_result(
GemmTuneResult(desp, tester.arch, 32),
a,
b,
weight_tv.select(1, i),
True,
False,
False,
arch=arch,
stream_int=stream,
shuffle_type=ShuffleStrideType.ShuffleAB.value,
a_inds=a_inds,
b_inds=b_inds,
c_inds=tv.Tensor(),
hint=AlgoHint.BackwardWeight.value,
alpha=1.0,
beta=beta)
else:
GEMM.run_with_tuned_result(BestAlgoByProfile(desp, tester.arch, 32),
a,
b,
weight_tv.select(1, i),
True,
False,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAB,
a_inds=a_inds,
b_inds=b_inds,
hint=AlgoHint.BackwardWeight.value,
alpha=1.0,
beta=beta)
GEMM.run_with_tuned_result(BestAlgoByProfile(desp, 32),
a,
b,
weight_tv.select(1, i),
True,
False,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAB,
a_inds=a_inds,
b_inds=b_inds,
hint=AlgoHint.BackwardWeight.value,
alpha=1.0,
beta=beta)
dw_my = weight_tv.cpu().numpy()
if dtype != np.float16:
error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1))
assert error_norm < 1 * multipler
# test_case.assertAllClose(dw_ref, dw_my, atol=atol, rtol=rtol)
# print(desp, error_norm)
assert error_norm < 1 * multipler, f"{desp}, {error_norm}"
else:
error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1))
# print(desp, error_norm)
assert error_norm < 10 * multipler
assert error_norm < 10 * multipler, f"{desp}, {error_norm}"
def test_all_algo_unit():
# for i in range(5):
_test_impgemm_conv_cuda(True)
# _test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(False)
_test_impgemm_conv_cuda(False)
_test_native_conv_cuda(True)
_test_native_conv_cuda(False)
if __name__ == "__main__":
......
......@@ -248,7 +248,7 @@ def test_spconv3d():
ConvAlgo.Native, ConvAlgo.MaskImplicitGemm,
ConvAlgo.MaskSplitImplicitGemm
]
# algos = [ConvAlgo.Native]
algos = [ConvAlgo.Native]
for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes,
......@@ -308,7 +308,6 @@ def test_spconv3d():
filters_t = torch.from_numpy(filters).to(device).to(dtype)
net_ref.net[0].weight.data[:] = filters_t.permute(
0, 4, 1, 2, 3).contiguous()
net.net[0].weight.data[:] = filters_t
out_ref = net_ref(features_dense_t)
out = net(features_t, indices_t, bs).dense()
......@@ -529,4 +528,4 @@ def test_spmaxpool3d():
if __name__ == "__main__":
test_spmaxpool3d()
test_spconv3d()
......@@ -222,6 +222,7 @@ class NetLight(nn.Module):
def _test_multi_impl(dtype: torch.dtype):
# TODO pytorch 1.12 don't support cpu half mm, f**k pytorch
# TODO remove or release this when tf32 op is ready
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
......@@ -239,8 +240,6 @@ def _test_multi_impl(dtype: torch.dtype):
np.float32)
coors = np.ascontiguousarray(
sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
device = torch.device("cuda:0")
device_cpu = torch.device("cpu:0")
......@@ -275,17 +274,21 @@ def _test_multi_impl(dtype: torch.dtype):
dout_t = torch.from_numpy(dout).to(device_cpu).to(dtype)
dout_t_cu = torch.from_numpy(dout).to(device).to(dtype)
t = time.time()
print(1, time.time() - t)
out_cpu = net_native_cpu(voxels_th, coors_th, 1).dense()
out_cpu.backward(dout_t)
if dtype != torch.float16:
out_cpu.backward(dout_t)
out = net_native_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
print(2, time.time() - t)
out.backward(dout_t_cu)
out_imp = net_imp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
print(3, time.time() - t)
out_imp.backward(dout_t_cu)
out_simp = net_simp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
print(4, time.time() - t)
out_simp.backward(dout_t_cu)
with torch.no_grad():
......@@ -297,6 +300,7 @@ def _test_multi_impl(dtype: torch.dtype):
error_native = torch.linalg.norm(dense_cpu - dense_native).cpu().item()
error_imp = torch.linalg.norm(dense_cpu - dense_imp).cpu().item()
error_simp = torch.linalg.norm(dense_cpu - dense_simp).cpu().item()
print(5, time.time() - t)
print("error_native", error_native)
print("error_imp", error_imp)
......@@ -320,15 +324,15 @@ def _test_multi_impl(dtype: torch.dtype):
native_w = native_params[k]
imp_w = imp_params[k]
simp_w = simp_params[k]
cpu_w_grad = cpu_w.grad.detach().cuda()
native_w_grad = native_w.grad.detach()
imp_w_grad = imp_w.grad.detach()
simp_w_grad = simp_w.grad.detach()
error_native = torch.linalg.norm(native_w_grad - cpu_w_grad).cpu().item()
if dtype != torch.float16:
cpu_w_grad = cpu_w.grad.detach().cuda()
error_native = torch.linalg.norm(native_w_grad - cpu_w_grad).cpu().item()
error_imp = torch.linalg.norm(native_w_grad - imp_w_grad).cpu().item()
error_simp = torch.linalg.norm(native_w_grad - simp_w_grad).cpu().item()
print(k, error_native, error_imp, error_simp)
print(k, error_imp, error_simp)
assert error_imp < 1
assert error_simp < 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment