Commit 899008fa authored by yan.yan's avatar yan.yan
Browse files

working on c++ only

parent f78575ea
...@@ -248,8 +248,7 @@ class ConvOutLocIter(pccm.ParameterizedClass): ...@@ -248,8 +248,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
class SparseConvIndicesKernel(pccm.ParameterizedClass): class SparseConvIndicesKernel(pccm.ParameterizedClass):
def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType): def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType):
super().__init__() super().__init__()
self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel, self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel)
ThrustLib)
self.loc_iter = ConvOutLocIter(problem) self.loc_iter = ConvOutLocIter(problem)
self.add_param_class("spinds", self.loc_iter, "ConvLocIter") self.add_param_class("spinds", self.loc_iter, "ConvLocIter")
self.add_param_class("spinds", problem, "ConvProblem") self.add_param_class("spinds", problem, "ConvProblem")
...@@ -271,7 +270,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -271,7 +270,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("indice_pairs", code.arg("indice_pairs",
f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] f"{self.dtype_indices}*") # [2, kernelProd, MaxSize]
code.arg("indice_pairs_for_uniq", code.arg("indice_pairs_for_uniq",
f"TIndiceUniq*") # [2, kernelProd, MaxSize] f"TIndiceUniq*") # [kernelProd * MaxSize + 1]
code.arg("indice_num_per_loc", f"int*") # [kernelProd] code.arg("indice_num_per_loc", f"int*") # [kernelProd]
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
...@@ -340,7 +339,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -340,7 +339,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("indice_pairs_out_part", f"int*") # [kernelProd, MaxSize] code.arg("indice_pairs_out_part", f"int*") # [kernelProd, MaxSize]
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
code.arg("indices_pair_size", "int") code.arg("indices_pair_size", "int")
# TODO use block instead of filter_offset?
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size; auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
...@@ -358,6 +356,46 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -358,6 +356,46 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
""") """)
return code return code
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_bounded(self):
    """Build the CUDA kernel that re-compacts indice pairs for bounded outputs.

    If output indices are bounded, some (input, output) pairs produced by
    the earlier stage may be invalid, so the per-filter pair counters are
    rebuilt with an atomic increment and the surviving pairs are assigned
    again. The caller is expected to pass a temporary copy of the input
    half of the pairs as ``indice_pairs_in_part_temp`` (the host code in
    this file reuses the ``indice_pairs_uniq`` storage for that copy).

    For every filter offset (``blockIdx.y``) and every input index the
    kernel reads the key from ``indice_pairs_uniq_before_sort``, skips the
    ``key_type`` max sentinel, looks the key up in ``table`` and, when the
    lookup succeeds (offset != -1), appends the pair at the slot returned
    by the atomic increment of ``indice_num_per_loc[filter_offset]``.

    Returns:
        The ``pccm.FunctionCode`` describing the kernel.
    """
    code = pccm.FunctionCode()
    code.targ("TTable")
    # queried through lookup_offset() / value_ptr() inside the kernel
    code.arg("table", "TTable")
    code.arg("indice_pairs_uniq_before_sort",
             "const typename TTable::key_type*")  # [kernelProd, MaxSize]
    code.arg("indice_pairs_in_part_temp",
             "const int*")  # [kernelProd, MaxSize]
    code.arg("indice_pairs_in_part", "int*")  # [kernelProd, MaxSize]
    code.arg("indice_pairs_out_part", "int*")  # [kernelProd, MaxSize]
    code.arg("indice_num_per_loc", "int*")  # [kernelProd]
    code.arg("num_indices_in", "int")
    code.arg("indices_pair_size", "int")
    code.raw(f"""
    int filter_offset = blockIdx.y;
    auto indice_pairs_in_part_filter = indice_pairs_in_part + filter_offset * indices_pair_size;
    auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
    auto indice_pairs_in_part_temp_filter = indice_pairs_in_part_temp + filter_offset * indices_pair_size;
    auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * indices_pair_size;
    for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
        // Keep the key in the table's key type: narrowing it to the index
        // dtype would truncate wide keys, so the sentinel comparison below
        // could never match and the lookup would use a corrupted key.
        typename TTable::key_type output_coord_offset = indice_pairs_uniq_before_sort_filter[i];
        if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
            auto table_offset = table.lookup_offset(output_coord_offset);
            if (table_offset != -1){{
                int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
                indice_pairs_in_part_filter[old_num] = indice_pairs_in_part_temp_filter[i];
                indice_pairs_out_part_filter[old_num] = table.value_ptr()[table_offset];
            }}
        }}
    }}
    """)
    return code
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask(self): def calc_conv_indices_stage1_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -369,7 +407,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -369,7 +407,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("indice_pairs_bwd", code.arg("indice_pairs_bwd",
f"{self.dtype_indices}*") # [kernelProd, MaxSize] f"{self.dtype_indices}*") # [kernelProd, MaxSize]
code.arg("indice_pairs_for_uniq", code.arg("indice_pairs_for_uniq",
f"TIndiceUniq*") # [2, kernelProd, MaxSize] f"TIndiceUniq*") # [kernelProd * MaxSize + 1]
code.arg("indice_num_per_loc", f"int*") # [kernelProd] code.arg("indice_num_per_loc", f"int*") # [kernelProd]
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
...@@ -397,6 +435,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -397,6 +435,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i; // indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset; // indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset; // indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset; indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// }} // }}
}} }}
...@@ -420,7 +459,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -420,7 +459,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
code.arg("num_indices_out", "int") code.arg("num_indices_out", "int")
# TODO use block instead of filter_offset?
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
uint32_t filter_mask_fwd = (1u << (filter_offset)); uint32_t filter_mask_fwd = (1u << (filter_offset));
...@@ -458,7 +496,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -458,7 +496,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
code.arg("kv", "int") code.arg("kv", "int")
# TODO use block instead of filter_offset?
code.raw(f""" code.raw(f"""
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{ for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
uint32_t mask = 0; uint32_t mask = 0;
...@@ -749,18 +786,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -749,18 +786,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
indice_pairs.dim(2), kv, transposed); indice_pairs.dim(2), kv, transposed);
}}); }});
// thrust::device_ptr<{self.dtype_indices}> ptr_tr(indice_pairs_uniq.data_ptr<{self.dtype_indices}>());
// auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int));
// thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
// auto new_end = thrust::unique(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
// auto num_out_act = new_end - ptr_tr - 1;
// return num_out_act;
""") """)
return code # .ret("int") return code # .ret("int")
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_stage1_5(self): def generate_conv_inds_stage1_5(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_dependency(ThrustLib)
code.arg("indice_pairs_uniq", "tv::Tensor") code.arg("indice_pairs_uniq", "tv::Tensor")
code.arg("uniq_size", "int64_t") code.arg("uniq_size", "int64_t")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
...@@ -783,6 +815,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -783,6 +815,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds", "tv::Tensor") code.arg("indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds", "tv::Tensor")
code.arg("indice_num_per_loc", "tv::Tensor")
code.arg("num_out_act", "int") code.arg("num_out_act", "int")
code.arg("batch_size", "int") code.arg("batch_size", "int")
code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>") code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>")
...@@ -790,6 +824,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -790,6 +824,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>") f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("use_bound_algo", "bool", "false")
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream // TODO stream
...@@ -798,6 +834,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -798,6 +834,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error"); TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error"); TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
auto ctx = tv::Context();
ctx.set_cuda_stream(custream);
// indice_pairs: [2, kv, indices.dim(0)] // indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1] // indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
...@@ -805,6 +843,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -805,6 +843,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// auto timer = tv::CudaContextTimer<>(); // auto timer = tv::CudaContextTimer<>();
int64_t uniq_size = indice_pairs.size() / 2 + 1; int64_t uniq_size = indice_pairs.size() / 2 + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= num_out_act, "error"); TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) >= num_out_act, "error");
// int num_out_act_bounded = num_out_act;
// if (num_out_act_bound > 0){{
// num_out_act_bounded = std::min(num_out_act_bounded, num_out_act);
// }}
TV_ASSERT_RT_ERR(out_inds.dim(0) >= num_out_act && out_inds.dim(1) == {self.ndim + 1}, "error"); TV_ASSERT_RT_ERR(out_inds.dim(0) >= num_out_act && out_inds.dim(1) == {self.ndim + 1}, "error");
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream); tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
launcher_num_act_in.blocks.y = kv; launcher_num_act_in.blocks.y = kv;
...@@ -827,11 +869,29 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -827,11 +869,29 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
lanucher_build_hash(build_conv_hash_table<table_t>, hash, lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(), out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act); loc_iter.layout_npq, num_out_act);
launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash, if (!use_bound_algo){{
indice_pairs_uniq_before_sort.data_ptr<const K>(), launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash,
indice_pairs[1].data_ptr<int>(), indice_pairs_uniq_before_sort.data_ptr<const K>(),
indices.dim(0), indice_pairs[1].data_ptr<int>(),
indice_pairs.dim(2)); indices.dim(0),
indice_pairs.dim(2));
}}else{{
indice_num_per_loc.zero_(ctx);
// copy previous pair in to indice_pairs_uniq
// we need to ensure size of indice_pairs_uniq larger than pair in
TV_ASSERT_RT_ERR({pccm.literal(self.dtype_indices == dtypes.int32)}, "error");
tv::Tensor indice_pairs_in_temp = tv::from_blob(indice_pairs_uniq.raw_data(), {{indice_pairs.dim(1), indice_pairs.dim(2)}},
indice_pairs.dtype(), indice_pairs.device());
indice_pairs_in_temp.copy_(indice_pairs[0].view(-1), ctx);
launcher_num_act_in(calc_conv_indices_stage2_bounded<table_t>, hash,
indice_pairs_uniq_before_sort.data_ptr<const K>(),
indice_pairs_in_temp.data_ptr<const int>(),
indice_pairs[0].data_ptr<int>(),
indice_pairs[1].data_ptr<int>(),
indice_num_per_loc.data_ptr<int>(),
indices.dim(0),
indice_pairs.dim(2));
}}
}}); }});
return num_out_act; return num_out_act;
""") """)
...@@ -899,6 +959,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -899,6 +959,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>") f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream // TODO stream
......
import os
import fire
from cumm.common import CompileInfo
from cumm.conv.main import ConvMainUnitTest
from cumm.gemm.main import GemmMainUnitTest
from pccm.builder.pybind import gen_cmake
from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS,
IMPLGEMM_VOLTA_PARAMS, SHUFFLE_SIMT_PARAMS,
SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS)
from spconv.csrc.hash.core import HashTable
from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.sparse.alloc import ExternalAllocator
from spconv.csrc.sparse.convops import (ConvGemmOps, ConvTunerSimple,
ExternalSpconvMatmul, GemmTunerSimple,
SimpleExternalSpconvMatmul)
from spconv.csrc.utils import BoxOps
def main(include: str,
         src: str,
         libname: str = "spconv",
         prefix: str = "spconvlib"):
    """Assemble every spconv C++ component and pass them to ``gen_cmake``.

    GEMM and implicit-GEMM kernel parameters flagged ``is_nvrtc`` are
    dropped before the kernel units are built.

    Args:
        include: forwarded to ``gen_cmake`` as its third positional argument
            (presumably the header output directory -- confirm against pccm).
        src: forwarded to ``gen_cmake`` as its fourth positional argument
            (presumably the source output directory -- confirm against pccm).
        libname: library name handed to ``gen_cmake``.
        prefix: value passed as ``namespace_prefix``.
    """
    # GEMM kernels: keep only the parameter sets that are not NVRTC ones.
    shuffle_params = [
        param for param in (SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS +
                            SHUFFLE_TURING_PARAMS)
        if not param.is_nvrtc
    ]
    gemm_main = GemmMainUnitTest(shuffle_params)
    gemm_main.namespace = "cumm.gemm.main"

    # Implicit-GEMM convolution kernels, filtered the same way.
    implgemm_params = [
        param for param in (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
                            IMPLGEMM_TURING_PARAMS)
        if not param.is_nvrtc
    ]
    conv_main = ConvMainUnitTest(implgemm_params)
    conv_main.namespace = "cumm.conv.main"

    # Tuners wrap the kernel units; ConvGemmOps ties both tuners together.
    gemm_tuner = GemmTunerSimple(gemm_main)
    gemm_tuner.namespace = "csrc.sparse.convops.gemmops"
    conv_tuner = ConvTunerSimple(conv_main)
    conv_tuner.namespace = "csrc.sparse.convops.convops"
    conv_gemm_ops = ConvGemmOps(gemm_tuner, conv_tuner)
    conv_gemm_ops.namespace = "csrc.sparse.convops.spops"

    components = [
        gemm_main,
        conv_main,
        gemm_tuner,
        conv_tuner,
        conv_gemm_ops,
        SpconvOps(),
        BoxOps(),
        HashTable(),
        CompileInfo(),
        ExternalAllocator(),
        ExternalSpconvMatmul(),
        SimpleExternalSpconvMatmul(),
    ]
    gen_cmake(libname, components, include, src, namespace_prefix=prefix)
# Script entry point: python-fire turns main()'s parameters into CLI arguments.
if __name__ == "__main__":
    fire.Fire(main)
...@@ -38,20 +38,6 @@ from torch.nn.init import calculate_gain ...@@ -38,20 +38,6 @@ from torch.nn.init import calculate_gain
FILTER_HWIO = False FILTER_HWIO = False
def expand_nd(val: Union[int, List[int], Tuple[int, ...]], ndim: int) -> List[int]:
if isinstance(val, int):
val = [val] * ndim
elif isinstance(val, list):
assert len(val) == ndim
elif isinstance(val, tuple):
assert len(val) == ndim
return [*val]
else:
raise NotImplementedError
return val
class SparseConvolution(SparseModule): class SparseConvolution(SparseModule):
__constants__ = [ __constants__ = [
'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse', 'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
...@@ -82,6 +68,7 @@ class SparseConvolution(SparseModule): ...@@ -82,6 +68,7 @@ class SparseConvolution(SparseModule):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.kernel_size = expand_nd(ndim, kernel_size) self.kernel_size = expand_nd(ndim, kernel_size)
self.stride = expand_nd(ndim, stride) self.stride = expand_nd(ndim, stride)
kv = int(np.prod(self.kernel_size)) kv = int(np.prod(self.kernel_size))
kv_stride = int(np.prod(self.stride)) kv_stride = int(np.prod(self.stride))
...@@ -130,7 +117,6 @@ class SparseConvolution(SparseModule): ...@@ -130,7 +117,6 @@ class SparseConvolution(SparseModule):
# KRSC # KRSC
self.weight = Parameter( self.weight = Parameter(
torch.Tensor(out_channels, *self.kernel_size, in_channels)) torch.Tensor(out_channels, *self.kernel_size, in_channels))
if bias: if bias:
self.bias = Parameter(torch.Tensor(out_channels)) self.bias = Parameter(torch.Tensor(out_channels))
else: else:
......
...@@ -15,9 +15,13 @@ ...@@ -15,9 +15,13 @@
from cumm import tensorview as tv from cumm import tensorview as tv
import torch import torch
from typing import Dict, Optional, List, Union from typing import Dict, Optional, List, Union
from spconv.constants import AllocKeys
from spconv.cppconstants import COMPILED_CUDA_ARCHS from spconv.cppconstants import COMPILED_CUDA_ARCHS
import sys import sys
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.core_cc.csrc.sparse.convops import ExternalSpconvMatmul
import numpy as np
_TORCH_DTYPE_TO_TV = { _TORCH_DTYPE_TO_TV = {
torch.float32: tv.float32, torch.float32: tv.float32,
...@@ -31,8 +35,16 @@ _TORCH_DTYPE_TO_TV = { ...@@ -31,8 +35,16 @@ _TORCH_DTYPE_TO_TV = {
} }
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()} _TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TORCH_UINT_WORKAROUNDS = {tv.uint32: tv.int32, tv.uint16: tv.int16, tv.uint64: tv.int64} _TORCH_UINT_WORKAROUNDS = {
_ALL_INTS = {tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32, tv.uint16} tv.uint32: tv.int32,
tv.uint16: tv.int16,
tv.uint64: tv.int64
}
_ALL_INTS = {
tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
tv.uint16
}
def torch_tensor_to_tv(ten: torch.Tensor, def torch_tensor_to_tv(ten: torch.Tensor,
dtype: Optional[int] = None, dtype: Optional[int] = None,
...@@ -62,6 +74,7 @@ def torch_tensor_to_tv(ten: torch.Tensor, ...@@ -62,6 +74,7 @@ def torch_tensor_to_tv(ten: torch.Tensor,
return tv.from_blob(ptr, shape, dtype, tv_device) return tv.from_blob(ptr, shape, dtype, tv_device)
return tv.from_blob_strided(ptr, shape, stride, dtype, tv_device) return tv.from_blob_strided(ptr, shape, stride, dtype, tv_device)
def torch_tensors_to_tv(*tens: torch.Tensor): def torch_tensors_to_tv(*tens: torch.Tensor):
return (torch_tensor_to_tv(t) for t in tens) return (torch_tensor_to_tv(t) for t in tens)
...@@ -69,28 +82,35 @@ def torch_tensors_to_tv(*tens: torch.Tensor): ...@@ -69,28 +82,35 @@ def torch_tensors_to_tv(*tens: torch.Tensor):
def get_current_stream(): def get_current_stream():
return torch.cuda.current_stream().cuda_stream return torch.cuda.current_stream().cuda_stream
def get_arch(): def get_arch():
arch = torch.cuda.get_device_capability() arch = torch.cuda.get_device_capability()
if arch not in COMPILED_CUDA_ARCHS: if arch not in COMPILED_CUDA_ARCHS:
print(f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, " print(
f"may cause invalid device function. " f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
f"available: {COMPILED_CUDA_ARCHS}", file=sys.stderr) f"may cause invalid device function. "
f"available: {COMPILED_CUDA_ARCHS}",
file=sys.stderr)
return arch return arch
class TorchAllocator(ExternalAllocator): class TorchAllocator(ExternalAllocator):
def __init__(self, gpudevice: torch.device) -> None: def __init__(self, gpudevice: torch.device) -> None:
super().__init__() super().__init__()
self.gpudevice = gpudevice self.gpudevice = gpudevice
self.cpudevice = torch.device("cpu:0") self.cpudevice = torch.device("cpu")
self.allocated: Dict[Union[str, int], torch.Tensor] = {} self.allocated: Dict[Union[str, int], torch.Tensor] = {}
def zeros(self, name: str, shape: List[int], dtype: int, device: int) -> tv.Tensor: def zeros(self, name: str, shape: List[int], dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor:
# TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit. # provide a name if you want to access it after c++ function exit.
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS: if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes" # assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype] dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
...@@ -99,18 +119,19 @@ class TorchAllocator(ExternalAllocator): ...@@ -99,18 +119,19 @@ class TorchAllocator(ExternalAllocator):
ten = torch.zeros(shape, dtype=th_dtype, device=dev) ten = torch.zeros(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten) ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten self.allocated[ten.data_ptr()] = ten
if name: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
if torch_uint_workaround: if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp) return ten_tv.type_view(dtype_bkp)
return ten_tv return ten_tv
def empty(self, name: str, shape: List[int], dtype: int, device: int) -> tv.Tensor: def empty(self, name: str, shape: List[int], dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor:
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS: if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes" # assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype] dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
...@@ -119,20 +140,21 @@ class TorchAllocator(ExternalAllocator): ...@@ -119,20 +140,21 @@ class TorchAllocator(ExternalAllocator):
ten = torch.empty(shape, dtype=th_dtype, device=dev) ten = torch.empty(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten) ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten self.allocated[ten.data_ptr()] = ten
if name: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
if torch_uint_workaround: if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp) return ten_tv.type_view(dtype_bkp)
return ten_tv return ten_tv
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int) -> tv.Tensor: def full_int(self, name: str, shape: List[int], value: int, dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0: if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes") raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS: if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes" assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype] dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1: if device == -1:
...@@ -142,22 +164,21 @@ class TorchAllocator(ExternalAllocator): ...@@ -142,22 +164,21 @@ class TorchAllocator(ExternalAllocator):
ten = torch.full(shape, value, dtype=th_dtype, device=dev) ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten) ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten self.allocated[ten.data_ptr()] = ten
if name: if name and not is_temp_memory:
self.allocated[name] = ten
if name:
self.allocated[name] = ten self.allocated[name] = ten
if torch_uint_workaround: if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp) return ten_tv.type_view(dtype_bkp)
return ten_tv return ten_tv
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int) -> tv.Tensor: def full_float(self, name: str, shape: List[int], value: float, dtype: int,
device: int, is_temp_memory: bool = False, stream: int = 0) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0: if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes") raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS: if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes" assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype] dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
...@@ -166,12 +187,15 @@ class TorchAllocator(ExternalAllocator): ...@@ -166,12 +187,15 @@ class TorchAllocator(ExternalAllocator):
ten = torch.full(shape, value, dtype=th_dtype, device=dev) ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten) ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten self.allocated[ten.data_ptr()] = ten
if name: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
if torch_uint_workaround: if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp) return ten_tv.type_view(dtype_bkp)
return ten_tv return ten_tv
def get_tensor_by_name(self, name: str):
return torch_tensor_to_tv(self.allocated[name])
def free(self, ten: tv.Tensor): def free(self, ten: tv.Tensor):
if ten.storage_bytesize() != ten.bytesize(): if ten.storage_bytesize() != ten.bytesize():
raise ValueError("you can't free a sliced tensor.") raise ValueError("you can't free a sliced tensor.")
...@@ -189,6 +213,130 @@ class TorchAllocator(ExternalAllocator): ...@@ -189,6 +213,130 @@ class TorchAllocator(ExternalAllocator):
return return
class TorchSpconvMatmul(ExternalSpconvMatmul):
def __init__(self, alloc: TorchAllocator) -> None:
super().__init__()
self.alloc = alloc
def indice_conv_init_gemm(self, features_n: str, filters_n: str,
all_weight_is_krsc: bool, is_kc_not_ck: bool,
kv_center: int, out_channel: int, stream_int: int = 0):
features = self.alloc.allocated[features_n]
filters = self.alloc.allocated[filters_n]
if not all_weight_is_krsc:
filters = filters.reshape(-1, *filters.shape[-2:])
if not is_kc_not_ck:
out_features = torch.mm(features, filters[kv_center])
else:
out_features = torch.mm(features, filters[kv_center].T)
else:
filters = filters.reshape(out_channel, -1, filters.shape[-1])
if features.is_cuda or (features.dtype != torch.float16):
out_features = torch.mm(features, filters[:, kv_center].T)
else:
# pytorch 1.12 don't support cpu half mm, f**k pytorch
# we need cpu fp16 mm for test only.
out_features = torch.empty((features.shape[0], out_channel),
dtype=features.dtype,
device=features.device)
features_np = torch_tensor_to_tv(features).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
out_features_np = torch_tensor_to_tv(out_features).numpy_view()
np.matmul(features_np,
filters_np[:, kv_center].T,
out=out_features_np)
self.alloc.allocated[AllocKeys.OutFeatures] = out_features
# print(filters.shape, features.shape, all_weight_is_krsc, out_features.shape, out_features.is_contiguous())
return torch_tensor_to_tv(out_features)
def indice_conv_cpu_gemm(self, inp_buffer_n: str, out_buffer_n: str, filters_n: str,
all_weight_is_krsc: bool,
is_kc_not_ck: bool, nhot: int, index: int):
kv_dim = 1 if all_weight_is_krsc else 0
inp_buffer = self.alloc.allocated[inp_buffer_n]
filters = self.alloc.allocated[filters_n]
if not all_weight_is_krsc:
filters = filters.reshape(-1, *filters.shape[-2:])
else:
filters = filters.reshape(filters.shape[0], -1, filters.shape[-1])
out_buffer = self.alloc.allocated[out_buffer_n]
filters_i = filters.select(kv_dim, index)
filters_cur = filters_i if not is_kc_not_ck else filters_i.T
if inp_buffer.dtype == torch.float16:
inp_buffer_np = torch_tensor_to_tv(inp_buffer).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
filters_i_np = filters_np[
index] if not all_weight_is_krsc else filters_np[:, index]
filters_cur_np = filters_i_np if not is_kc_not_ck else filters_i_np.T
out_buffer_np = torch_tensor_to_tv(out_buffer).numpy_view()
np.matmul(inp_buffer_np[:nhot],
filters_cur_np,
out=out_buffer_np[:nhot])
else:
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
def indice_conv_bwd_init_gemm(self, features_n: str, filters_n: str,
out_bp_n: str, dfilters_n: str,
all_weight_is_krsc: bool, is_kc_not_ck: bool,
kv_center: int, stream_int: int = 0):
features = self.alloc.allocated[features_n]
filters = self.alloc.allocated[filters_n]
out_bp = self.alloc.allocated[out_bp_n]
dfilters = self.alloc.allocated[dfilters_n]
if not all_weight_is_krsc:
filters = filters.reshape(-1, *filters.shape[-2:])
dfilters = dfilters.reshape(-1, *filters.shape[-2:])
else:
filters = filters.reshape(filters.shape[0], -1, filters.shape[-1])
dfilters = dfilters.reshape(filters.shape[0], -1, filters.shape[-1])
if not all_weight_is_krsc:
if not is_kc_not_ck:
torch.mm(features.T, out_bp, out=dfilters[kv_center])
din = torch.mm(out_bp, filters[kv_center].T)
else:
torch.mm(out_bp.T, features, out=dfilters[kv_center])
din = torch.mm(out_bp, filters[kv_center])
else:
# KN @ NC
torch.mm(out_bp.T, features, out=dfilters[:, kv_center])
# NK @ KC
din = torch.mm(out_bp, filters[:, kv_center])
self.alloc.allocated[AllocKeys.DIn] = din
return torch_tensor_to_tv(din)
def indice_conv_bwd_cpu_gemm(self, inp_buffer_n: str,
                             out_buffer_n: str, filters_n: str,
                             dfilters_n: str, all_weight_is_krsc: bool,
                             is_kc_not_ck: bool, nhot: int, index: int):
    """Run the backward GEMMs for one kernel offset on gathered buffers.

    All tensors are looked up by name in ``self.alloc.allocated``. Only the
    first ``nhot`` rows of the two gather buffers take part. Two results are
    written in place:

    * the ``index`` slice of the filter-gradient buffer receives the filter
      gradient;
    * the first ``nhot`` rows of the input buffer are overwritten with the
      input gradient (output-gradient rows times the filter slice).

    Args:
        inp_buffer_n: allocator key of the gathered input buffer.
        out_buffer_n: allocator key of the gathered output-gradient buffer.
        filters_n: allocator key of the filters.
        dfilters_n: allocator key of the filter-gradient buffer.
        all_weight_is_krsc: when true the kernel-offset axis is axis 1 of
            the filters (after flattening), otherwise it is axis 0.
        is_kc_not_ck: whether a filter slice is already laid out as
            (out_channel, in_channel); if not, it is transposed before use.
        nhot: number of valid rows in both gather buffers.
        index: kernel offset to process.
    """
    buffers = self.alloc.allocated
    gathered_in = buffers[inp_buffer_n]
    gathered_out = buffers[out_buffer_n]
    weight = buffers[filters_n]
    grad_weight = buffers[dfilters_n]

    if all_weight_is_krsc:
        kv_axis = 1
        weight = weight.reshape(weight.shape[0], -1, weight.shape[-1])
        grad_weight = grad_weight.reshape(weight.shape[0], -1,
                                          weight.shape[-1])
    else:
        kv_axis = 0
        weight = weight.reshape(-1, *weight.shape[-2:])
        grad_weight = grad_weight.reshape(-1, *weight.shape[-2:])

    weight_i = weight.select(kv_axis, index)
    grad_weight_i = grad_weight.select(kv_axis, index)
    in_hot = gathered_in[:nhot]
    out_hot = gathered_out[:nhot]

    if is_kc_not_ck:
        weight_kc = weight_i
        # KN @ NC
        torch.mm(out_hot.T, in_hot, out=grad_weight_i)
    else:
        weight_kc = weight_i.T
        # CN @ NK
        torch.mm(in_hot.T, out_hot, out=grad_weight_i)
    # NK @ KC -- the input buffer is reused to hold the input gradient.
    torch.mm(out_hot, weight_kc, out=in_hot)
if __name__ == "__main__": if __name__ == "__main__":
a = torch.rand(2, 2) a = torch.rand(2, 2)
atv = torch_tensor_to_tv(a) atv = torch_tensor_to_tv(a)
......
...@@ -23,23 +23,25 @@ import spconv ...@@ -23,23 +23,25 @@ import spconv
from spconv.core import AlgoHint, ConvAlgo from spconv.core import AlgoHint, ConvAlgo
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from spconv.pytorch.core import ThrustSortAllocator from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream, get_arch from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM
import spconv.core_cc as _ext import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.utils import nullcontext from spconv.utils import nullcontext
if hasattr(_ext, "cumm"): if hasattr(_ext, "cumm"):
CPU_ONLY_BUILD = False CPU_ONLY_BUILD = False
from spconv.algo import GEMM, CONV # , GATHER, SCATTER from spconv.algo import GEMM, CONV, GEMM_CPP, CONV_CPP
else: else:
CPU_ONLY_BUILD = True CPU_ONLY_BUILD = True
GEMM = None GEMM = None
CONV = None CONV = None
GEMM_CPP = None
CONV_CPP = None
import time import time
from spconv.constants import FILTER_HWIO, ALL_WEIGHT_IS_KRSC from spconv.constants import FILTER_HWIO, ALL_WEIGHT_IS_KRSC, AllocKeys
from cumm.gemm import codeops from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
...@@ -103,14 +105,30 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -103,14 +105,30 @@ def get_indice_pairs(indices: torch.Tensor,
dilation: List[int], dilation: List[int],
out_padding: List[int], out_padding: List[int],
subm: bool = False, subm: bool = False,
transpose: bool = False): transpose: bool = False,
num_out_act_bound: int = -1):
# torch.cuda.synchronize() # torch.cuda.synchronize()
# t = time.time() # t = time.time()
# stream = get_current_stream() # stream = get_current_stream()
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
if SPCONV_CPP_INDICE_PAIRS:
alloc = TorchAllocator(indices.device)
stream = 0
if indices.is_cuda:
stream = get_current_stream()
num_act_out = SpconvOps.get_indice_pairs(alloc, torch_tensor_to_tv(indices), batch_size, spatial_shape,
algo.value, ksize, stride, padding, dilation, out_padding, subm, transpose, stream)
if subm:
out_inds = indices
else:
out_inds = alloc.allocated[AllocKeys.OutIndices]
pair = alloc.allocated[AllocKeys.Pair]
indice_num_per_loc = alloc.allocated[AllocKeys.IndiceNumPerLoc]
# print(subm, out_inds.shape, pair.shape, indice_num_per_loc.shape, num_act_out)
return out_inds[:num_act_out], pair, indice_num_per_loc
ndim = indices.shape[1] - 1 ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1) kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
if not subm: if not subm:
...@@ -152,7 +170,6 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -152,7 +170,6 @@ def get_indice_pairs(indices: torch.Tensor,
# device=indices.device) # device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds) out_inds_tv = torch_tensor_to_tv(out_inds)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64) # hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
SpconvOps.generate_subm_conv_inds(inds_tv, SpconvOps.generate_subm_conv_inds(inds_tv,
hashdata.hashdata_k_tv, hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv, hashdata.hashdata_v_tv,
...@@ -200,6 +217,10 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -200,6 +217,10 @@ def get_indice_pairs(indices: torch.Tensor,
stream_int=stream) stream_int=stream)
uniq_res = indice_pairs_uniq.unique() uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1 num_act_out = uniq_res.shape[0] - 1
use_bound_algo = False
if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound
use_bound_algo = True
uniq_res_tv = torch_tensor_to_tv(uniq_res) uniq_res_tv = torch_tensor_to_tv(uniq_res)
# num_act_out = SpconvOps.generate_conv_inds_stage1_5( # num_act_out = SpconvOps.generate_conv_inds_stage1_5(
# indice_pairs_uniq_tv, # indice_pairs_uniq_tv,
...@@ -224,6 +245,7 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -224,6 +245,7 @@ def get_indice_pairs(indices: torch.Tensor,
uniq_res_tv, uniq_res_tv,
indice_pairs_uniq_tv, indice_pairs_uniq_tv,
out_inds_tv, out_inds_tv,
indice_num_per_loc_tv,
num_out_act=num_act_out, num_out_act=num_act_out,
batch_size=batch_size, batch_size=batch_size,
output_dims=out_shape, output_dims=out_shape,
...@@ -233,7 +255,8 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -233,7 +255,8 @@ def get_indice_pairs(indices: torch.Tensor,
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
transposed=transpose, transposed=transpose,
stream_int=stream) stream_int=stream,
use_bound_algo=use_bound_algo)
else: else:
out_inds = torch.empty((kv * indices.shape[0], indices.shape[1]), out_inds = torch.empty((kv * indices.shape[0], indices.shape[1]),
dtype=indices.dtype, dtype=indices.dtype,
...@@ -258,7 +281,6 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -258,7 +281,6 @@ def get_indice_pairs(indices: torch.Tensor,
# print("REGU", time.time() - t) # print("REGU", time.time() - t)
return out_inds, pair, indice_num_per_loc return out_inds, pair, indice_num_per_loc
def get_indice_pairs_implicit_gemm( def get_indice_pairs_implicit_gemm(
indices: torch.Tensor, indices: torch.Tensor,
batch_size: int, batch_size: int,
...@@ -273,7 +295,8 @@ def get_indice_pairs_implicit_gemm( ...@@ -273,7 +295,8 @@ def get_indice_pairs_implicit_gemm(
transpose: bool = False, transpose: bool = False,
is_train: bool = True, is_train: bool = True,
alloc: Optional[ThrustSortAllocator] = None, alloc: Optional[ThrustSortAllocator] = None,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False),
num_out_act_bound: int = -1):
""" """
Why return tuple? because pytorch seems don't support custom object in autograd. Why return tuple? because pytorch seems don't support custom object in autograd.
return: ( return: (
...@@ -289,6 +312,62 @@ def get_indice_pairs_implicit_gemm( ...@@ -289,6 +312,62 @@ def get_indice_pairs_implicit_gemm(
) )
""" """
stream = get_current_stream() stream = get_current_stream()
if SPCONV_CPP_INDICE_PAIRS_IGEMM:
thalloc = TorchAllocator(indices.device)
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc, torch_tensor_to_tv(indices), batch_size, spatial_shape,
algo.value, ksize, stride, padding, dilation, out_padding, subm, transpose, is_train, stream,
num_out_act_bound)
mask_split_count = mask_tensor.dim(0)
masks = [mask_tensor[i:i+1].numpy() for i in range(mask_split_count)]
if subm:
out_inds = indices
else:
out_inds = thalloc.allocated[AllocKeys.OutIndices]
pair = thalloc.allocated[AllocKeys.Pair]
indice_num_per_loc = thalloc.allocated[AllocKeys.IndiceNumPerLoc]
if subm:
pair_mask = thalloc.allocated[AllocKeys.PairMask]
mask_argsort = thalloc.allocated[AllocKeys.MaskArgSort]
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
mask_argsort_in_splits = [
mask_argsort[i] for i in range(mask_split_count)
]
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
pair_bwd = pair
pair_fwd = thalloc.allocated[AllocKeys.PairFwd]
pair_mask_fwd = thalloc.allocated[AllocKeys.PairMask]
pair_mask_bwd = torch.Tensor()
mask_argsort_bwd = torch.Tensor()
if is_train:
pair_mask_bwd = thalloc.allocated[AllocKeys.PairMaskBwd]
mask_argsort_bwd = thalloc.allocated[AllocKeys.MaskArgSortBwd]
mask_argsort_fwd = thalloc.allocated[AllocKeys.MaskArgSort]
if not is_train:
pair_bwd = torch.Tensor()
pair_mask_bwd_splits: List[torch.Tensor] = []
mask_argsort_bwd_splits: List[torch.Tensor] = []
else:
pair_mask_bwd_splits = [
pair_mask_bwd[i] for i in range(mask_split_count)
]
mask_argsort_bwd_splits = [
mask_argsort_bwd[i] for i in range(mask_split_count)
]
pair_mask_fwd_splits = [
pair_mask_fwd[i] for i in range(mask_split_count)
]
mask_argsort_fwd_splits = [
mask_argsort_fwd[i] for i in range(mask_split_count)
]
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
t = 0 t = 0
if DEBUG: if DEBUG:
CONV.stream_synchronize(stream) CONV.stream_synchronize(stream)
...@@ -443,6 +522,8 @@ def get_indice_pairs_implicit_gemm( ...@@ -443,6 +522,8 @@ def get_indice_pairs_implicit_gemm(
uniq_res = indice_pairs_uniq.unique() uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1 num_act_out = uniq_res.shape[0] - 1
if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound
if DEBUG: if DEBUG:
CONV.stream_synchronize(stream) CONV.stream_synchronize(stream)
...@@ -627,10 +708,36 @@ def indice_conv(features: torch.Tensor, ...@@ -627,10 +708,36 @@ def indice_conv(features: torch.Tensor,
# t = time.time() # t = time.time()
if not features.is_contiguous(): if not features.is_contiguous():
features = features.contiguous() features = features.contiguous()
if features.dtype == torch.int8 or features.dtype == torch.qint8: if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
if SPCONV_CPP_GEMM and GEMM_CPP is not None:
# print("CPPPPPP!!!", features.device)
alloc = TorchAllocator(features.device)
from spconv.core_cc.csrc.sparse.convops import SimpleExternalSpconvMatmul
# ext_mm = TorchSpconvMatmul(alloc)
if features.is_cuda:
ext_mm = SimpleExternalSpconvMatmul(alloc)
else:
ext_mm = TorchSpconvMatmul(alloc)
alloc.allocated[AllocKeys.Features] = features
alloc.allocated[AllocKeys.Filters] = filters
features_tv = torch_tensor_to_tv(features)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num)
filters_tv = torch_tensor_to_tv(filters)
stream = 0
if features.is_cuda:
stream = get_current_stream()
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, indice_pairs_tv, indice_pair_num_tv, num_activate_out,
inverse, subm, algo.value, stream)
out_features = alloc.allocated[AllocKeys.OutFeatures]
return out_features
if not ALL_WEIGHT_IS_KRSC: if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0 kv_dim = 0
is_KC_not_CK = not FILTER_HWIO is_KC_not_CK = not FILTER_HWIO
...@@ -642,7 +749,6 @@ def indice_conv(features: torch.Tensor, ...@@ -642,7 +749,6 @@ def indice_conv(features: torch.Tensor,
filter_shape_per_kv = [out_channel, filters.shape[-1]] filter_shape_per_kv = [out_channel, filters.shape[-1]]
filters = filters.reshape(-1, *filters.shape[-2:]) filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0] kv = filters.shape[0]
else: else:
kv_dim = 1 kv_dim = 1
out_channel = filters.shape[0] out_channel = filters.shape[0]
...@@ -651,7 +757,6 @@ def indice_conv(features: torch.Tensor, ...@@ -651,7 +757,6 @@ def indice_conv(features: torch.Tensor,
kv = filters.shape[1] kv = filters.shape[1]
filter_shape_per_kv = [out_channel, filters.shape[-1]] filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2 kv_center = kv // 2
if subm: if subm:
# out_features = torch.zeros((num_activate_out, out_channel), # out_features = torch.zeros((num_activate_out, out_channel),
...@@ -663,7 +768,19 @@ def indice_conv(features: torch.Tensor, ...@@ -663,7 +768,19 @@ def indice_conv(features: torch.Tensor,
else: else:
out_features = torch.mm(features, filters[kv_center].T) out_features = torch.mm(features, filters[kv_center].T)
else: else:
out_features = torch.mm(features, filters[:, kv_center].T) if features.is_cuda or (features.dtype != torch.float16):
out_features = torch.mm(features, filters[:, kv_center].T)
else:
# pytorch 1.12 don't support cpu half mm, f**k pytorch
# we need cpu fp16 mm for test only.
out_features = torch.empty((features.shape[0], out_channel),
dtype=features.dtype,
device=features.device)
features_np = torch_tensor_to_tv(features).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
out_features_np = torch_tensor_to_tv(out_features).numpy_view()
np.matmul(features_np, filters_np[:, kv_center].T, out=out_features_np)
# out_features = torch.mm(features, filters[:, kv_center].T)
else: else:
out_features = torch.zeros((num_activate_out, out_channel), out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype, dtype=features.dtype,
...@@ -706,7 +823,15 @@ def indice_conv(features: torch.Tensor, ...@@ -706,7 +823,15 @@ def indice_conv(features: torch.Tensor,
SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices) SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices)
filters_i = filters.select(kv_dim, i) filters_i = filters.select(kv_dim, i)
filters_cur = filters_i if not is_KC_not_CK else filters_i.T filters_cur = filters_i if not is_KC_not_CK else filters_i.T
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot]) if features.dtype == torch.float16:
inp_buffer_np = torch_tensor_to_tv(inp_buffer).numpy_view()
filters_np = torch_tensor_to_tv(filters).numpy_view()
filters_i_np = filters_np[i] if not ALL_WEIGHT_IS_KRSC else filters_np[:, i]
filters_cur_np = filters_i_np if not is_KC_not_CK else filters_i_np.T
out_buffer_np = torch_tensor_to_tv(out_buffer).numpy_view()
np.matmul(inp_buffer_np[:nhot], filters_cur_np, out=out_buffer_np[:nhot])
else:
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices) SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices)
return out_features return out_features
...@@ -750,7 +875,7 @@ def indice_conv(features: torch.Tensor, ...@@ -750,7 +875,7 @@ def indice_conv(features: torch.Tensor,
profile_idx, :nhot_profile] profile_idx, :nhot_profile]
inp_indices = torch_tensor_to_tv(inp_indices_th) inp_indices = torch_tensor_to_tv(inp_indices_th)
out_indices = torch_tensor_to_tv(out_indices_th) out_indices = torch_tensor_to_tv(out_indices_th)
filter_tv = torch_tensor_to_tv(filters)[profile_idx] # filter_tv = torch_tensor_to_tv(filters)[profile_idx]
filter_tv = torch_tensor_to_tv(filters).select(kv_dim, profile_idx) filter_tv = torch_tensor_to_tv(filters).select(kv_dim, profile_idx)
tuned_res, min_time = GEMM.tune_and_cache( tuned_res, min_time = GEMM.tune_and_cache(
...@@ -826,8 +951,40 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -826,8 +951,40 @@ def indice_conv_backward(features: torch.Tensor,
algo: ConvAlgo = ConvAlgo.Native, algo: ConvAlgo = ConvAlgo.Native,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# print(out_bp.mean(), out_bp.max(), out_bp.min()) # print(out_bp.mean(), out_bp.max(), out_bp.min())
if SPCONV_CPP_GEMM and GEMM_CPP is not None:
alloc = TorchAllocator(features.device)
ext_mm = TorchSpconvMatmul(alloc)
alloc.allocated[AllocKeys.Features] = features
alloc.allocated[AllocKeys.Filters] = filters
alloc.allocated[AllocKeys.OutBp] = out_bp
features_tv = torch_tensor_to_tv(features)
out_bp_tv = torch_tensor_to_tv(out_bp)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num)
filters_tv = torch_tensor_to_tv(filters)
stream = 0
if features.is_cuda:
stream = get_current_stream()
ConvGemmOps.indice_conv_backward(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv, indice_pairs_tv, indice_pair_num_tv,
inverse, subm, algo.value, stream)
din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters]
return din, df
filters_shape = filters.shape filters_shape = filters.shape
# TODO handle this in nn.Module to make sure features in backward is contiguous
if not features.is_contiguous():
features = features.contiguous()
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
assert out_bp.is_contiguous()
assert filters.is_contiguous()
assert features.is_contiguous()
if not ALL_WEIGHT_IS_KRSC: if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0 kv_dim = 0
is_KC_not_CK = not FILTER_HWIO is_KC_not_CK = not FILTER_HWIO
...@@ -849,14 +1006,6 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -849,14 +1006,6 @@ def indice_conv_backward(features: torch.Tensor,
filter_shape_per_kv = [out_channel, filters.shape[-1]] filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2 kv_center = kv // 2
# TODO handle this in nn.Module to make sure features in backward is contiguous
if not features.is_contiguous():
features = features.contiguous()
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
assert out_bp.is_contiguous()
assert filters.is_contiguous()
assert features.is_contiguous()
if subm: if subm:
dfilters = torch.zeros_like(filters) dfilters = torch.zeros_like(filters)
...@@ -1141,6 +1290,31 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1141,6 +1290,31 @@ def implicit_gemm(features: torch.Tensor,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None): fp32_accum: Optional[bool] = None):
stream = get_current_stream() stream = get_current_stream()
if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device)
features_tv = torch_tensor_to_tv(features)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_splits_tv = [torch_tensor_to_tv(t, tv.uint32) for t in pair_mask_fwd_splits]
mask_argsort_fwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_fwd_splits]
filters_tv = torch_tensor_to_tv(filters)
mask = np.concatenate(masks)
mask_tv = tv.from_numpy(mask).clone()
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
auto_fp32_accum = fp32_accum is None
if fp32_accum is None:
fp32_accum = False
mask_width = ConvGemmOps.implicit_gemm(alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, pair_mask_fwd_splits_tv,
mask_argsort_fwd_splits_tv, num_activate_out, mask_tv, is_train, is_subm, stream, timer_cpp, auto_fp32_accum,
fp32_accum)
out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train:
assert mask_output_fwd is not None
return out_features, mask_output_fwd, mask_width
# if DEBUG: # if DEBUG:
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
...@@ -1225,7 +1399,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1225,7 +1399,8 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
print(tune_res.algo_desp) # print(tune_res.algo_desp)
# with tv.measure_and_print("f16 time"):
with timer.record("implicit_gemm", stream): with timer.record("implicit_gemm", stream):
for j in range(num_split): for j in range(num_split):
beta = 0 if j == 0 else 1 beta = 0 if j == 0 else 1
...@@ -1245,6 +1420,81 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1245,6 +1420,81 @@ def implicit_gemm(features: torch.Tensor,
beta=beta, beta=beta,
stream=stream, stream=stream,
verbose=False) verbose=False)
# INT8_TEST = True
# if INT8_TEST:
# if features.shape[1] % 32 != 0:
# return out_features, mask_output_fwd, mask_width
# features = features.to(torch.int8)
# filters = filters.to(torch.int8)
# out_features_i8 = out_features.to(torch.int8)
# features_tv = torch_tensor_to_tv(features)
# filters_tv = torch_tensor_to_tv(filters)
# out_features_i8_tv = torch_tensor_to_tv(out_features_i8)
# tune_res = CONV.get_tuned_algo(ConvOpType.kForward, features_tv.dtype,
# filters_tv.dtype, out_features_i8_tv.dtype,
# out_channel, in_channel, arch)
# if tune_res is None:
# tune_res, _ = CONV.tune_and_cache(
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# NHWC,
# KRSC,
# NHWC,
# arch,
# mask=pair_mask_fwd_split_tvs[0],
# mask_argsort=mask_argsort_fwd_split_tvs[0],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks[0].item(),
# stream=stream,
# fp32_accum=fp32_accum)
# mask_width = tune_res.algo_desp.tile_shape[0]
# if is_train:
# mask_output_fwd = torch.empty(
# [num_split,
# codeops.div_up(num_activate_out, mask_width)],
# dtype=torch.int32,
# device=features.device)
# # pytorch don't support uint32.
# mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd,
# dtype=tv.uint32)
# mask_output_fwd_tvs = [mask_output_fwd_tv[j] for j in range(num_split)]
# else:
# mask_output_fwd = None
# mask_output_fwd_tv = tv.Tensor()
# mask_output_fwd_tvs = [tv.Tensor() for _ in range(num_split)]
# # CONV.stream_synchronize(stream)
# # print("FPREPARE", time.time() - t)
# # # t = time.time()
# # CONV.stream_synchronize(stream)
# # t = time.time()
# # print(tune_res.algo_desp)
# with tv.measure_and_print(f"i8 time {features.shape[0]}-{in_channel}-{out_channel}"):
# with timer.record("implicit_gemm_i8", stream):
# for j in range(num_split):
# beta = 0 if j == 0 else 1
# CONV.run_with_tuned_result(
# tune_res,
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# mask=pair_mask_fwd_split_tvs[j],
# mask_argsort=mask_argsort_fwd_split_tvs[j],
# mask_output=mask_output_fwd_tvs[j],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks_ints[j],
# mask_width=-1,
# beta=beta,
# stream=stream,
# verbose=False)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# if DEBUG: # if DEBUG:
...@@ -1266,7 +1516,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1266,7 +1516,7 @@ def implicit_gemm_backward(features: torch.Tensor,
pair_mask_bwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor],
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
mask_output_fwd: torch.Tensor, mask_output_fwd: Optional[torch.Tensor],
masks: List[np.ndarray], masks: List[np.ndarray],
mask_width: int, mask_width: int,
is_subm: bool, is_subm: bool,
...@@ -1275,14 +1525,53 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1275,14 +1525,53 @@ def implicit_gemm_backward(features: torch.Tensor,
# print(out_bp.mean(), out_bp.max(), out_bp.min()) # print(out_bp.mean(), out_bp.max(), out_bp.min())
if features.dtype == torch.int8 or features.dtype == torch.qint8: if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
if not out_bp.is_contiguous(): if not out_bp.is_contiguous():
out_bp = out_bp.contiguous() out_bp = out_bp.contiguous()
if not features.is_contiguous(): if not features.is_contiguous():
features = features.contiguous() features = features.contiguous()
if mask_output_fwd is None:
raise ValueError("you must do bwd with net.train()")
assert out_bp.is_contiguous() assert out_bp.is_contiguous()
assert filters.is_contiguous() assert filters.is_contiguous()
assert features.is_contiguous() assert features.is_contiguous()
stream = get_current_stream()
if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device)
features_tv = torch_tensor_to_tv(features)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_bwd_tv = torch_tensor_to_tv(pair_bwd)
pair_mask_fwd_splits_tv = [torch_tensor_to_tv(t) for t in pair_mask_fwd_splits]
pair_mask_bwd_splits_tv = [torch_tensor_to_tv(t) for t in pair_mask_bwd_splits]
mask_argsort_fwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_fwd_splits]
mask_argsort_bwd_splits_tv = [torch_tensor_to_tv(t) for t in mask_argsort_bwd_splits]
filters_tv = torch_tensor_to_tv(filters)
out_bp_tv = torch_tensor_to_tv(out_bp)
mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd)
mask = np.concatenate(masks)
mask_tv = tv.from_numpy(mask).clone()
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
auto_fp32_accum = fp32_accum is None
if fp32_accum is None:
fp32_accum = False
ConvGemmOps.implicit_gemm_backward(alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv, mask_argsort_fwd_splits_tv,
mask_argsort_bwd_splits_tv, mask_output_fwd_tv, mask_tv, mask_width, is_subm, stream, timer_cpp, auto_fp32_accum,
fp32_accum)
din = alloc.allocated[AllocKeys.DIn]
dfilters = alloc.allocated[AllocKeys.DFilters]
return din, dfilters
# here filters is KRSC # here filters is KRSC
filters_shape = filters.shape filters_shape = filters.shape
out_channel = filters.shape[0] out_channel = filters.shape[0]
...@@ -1297,7 +1586,6 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1297,7 +1586,6 @@ def implicit_gemm_backward(features: torch.Tensor,
filters = filters.reshape(out_channel, -1, filters.shape[-1]) filters = filters.reshape(out_channel, -1, filters.shape[-1])
kv = filters.shape[1] kv = filters.shape[1]
stream = get_current_stream()
pair_fwd_tv = torch_tensor_to_tv(pair_fwd) pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_bwd_tv = torch_tensor_to_tv(pair_bwd) pair_bwd_tv = torch_tensor_to_tv(pair_bwd)
...@@ -1366,12 +1654,12 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1366,12 +1654,12 @@ def implicit_gemm_backward(features: torch.Tensor,
KRSC, KRSC,
NHWC, NHWC,
arch, arch,
mask=pair_mask_fwd_split_tvs[0], mask=mask_output_fwd_tv[0],
mask_argsort=mask_argsort_fwd_split_tvs[0], mask_argsort=mask_argsort_fwd_split_tvs[0],
indices=pair_fwd_tv, indices=pair_fwd_tv,
reverse_mask=False, reverse_mask=False,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
mask_output=mask_output_fwd_tv[0], mask_output=tv.Tensor(),
mask_width=mask_width, mask_width=mask_width,
stream=stream) stream=stream)
workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp, workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
......
"""this file can only be used by spconv developer for now.
The "tensorpc" package isn't an open source project.
"""
import tensorpc
from tensorpc.apps.flow.flowapp import App
class TestApp(App):
    """Empty ``App`` subclass; adds no attributes or methods of its own."""
\ No newline at end of file
...@@ -25,7 +25,7 @@ import spconv.pytorch as spconv ...@@ -25,7 +25,7 @@ import spconv.pytorch as spconv
from spconv.utils import Point2VoxelCPU3d from spconv.utils import Point2VoxelCPU3d
# torch.backends.cudnn.enabled = False # torch.backends.cudnn.enabled = False
def waymo_data(batch_size=1): def waymo_data(batch_size=1, num_features=-1):
gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3, gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3,
150000, 1) 150000, 1)
# gen = VoxelGeneratorV2([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 1, # gen = VoxelGeneratorV2([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 1,
...@@ -35,11 +35,39 @@ def waymo_data(batch_size=1): ...@@ -35,11 +35,39 @@ def waymo_data(batch_size=1):
print(pc.shape) print(pc.shape)
voxels_tv, indices_tv, _ = gen.point_to_voxel(tv.from_numpy(pc)) voxels_tv, indices_tv, _ = gen.point_to_voxel(tv.from_numpy(pc))
voxels = voxels_tv.numpy().reshape(-1, 3) voxels = voxels_tv.numpy().reshape(-1, 3)
if num_features > 0:
voxels = np.zeros((voxels.shape[0], num_features), dtype=voxels.dtype)
coors = indices_tv.numpy() coors = indices_tv.numpy()
N = coors.shape[0] N = coors.shape[0]
coors = np.concatenate([np.full([N, 1], 0, coors.dtype), coors], axis=1) coors = np.concatenate([np.full([N, 1], 0, coors.dtype), coors], axis=1)
return voxels, coors, gen.grid_size return voxels, coors, gen.grid_size
def waymo_data_large(batch_size=1):
    """Build a larger voxelized benchmark input from the bundled point cloud.

    The point cloud stored in ``data/benchmark-pc.npz`` (key ``"pc"``) is
    stacked with four copies of itself whose column 1 is increased by 1, 2,
    3 and 4 respectively, then voxelized.

    Args:
        batch_size: accepted for signature parity; not used by this function.

    Returns:
        ``(voxels, coors, grid_size)`` where ``voxels`` is reshaped to three
        columns, ``coors`` has a zero-filled column prepended to the voxel
        indices, and ``grid_size`` comes from the voxel generator.
    """
    voxelizer = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6],
                                 3, 1200000, 1)
    archive = np.load(Path(__file__).parent / "data" / "benchmark-pc.npz")
    base_pc = np.ascontiguousarray(archive["pc"])

    # Original cloud followed by four copies, each shifted in column 1.
    clouds = [base_pc]
    for shift in range(1, 5):
        moved = base_pc.copy()
        moved[:, 1] += shift
        clouds.append(moved)
    pc = np.concatenate(clouds)
    print(pc.shape)

    voxels_tv, indices_tv, _ = voxelizer.point_to_voxel(tv.from_numpy(pc))
    voxels = voxels_tv.numpy().reshape(-1, 3)
    coors = indices_tv.numpy()
    N = coors.shape[0]
    print("num voxels", N)
    # Prepend a zero-filled column to the voxel indices.
    zero_col = np.full([N, 1], 0, coors.dtype)
    coors = np.concatenate([zero_col, coors], axis=1)
    return voxels, coors, voxelizer.grid_size
class Net(nn.Module): class Net(nn.Module):
def __init__(self, shape, algo): def __init__(self, shape, algo):
...@@ -61,6 +89,21 @@ class Net(nn.Module): ...@@ -61,6 +89,21 @@ class Net(nn.Module):
# # algo=algo), # # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0", # spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo), # algo=algo),
# spconv.SubMConv3d(64, 64, 3, bias=False, indice_key="c0",
# algo=algo),
# spconv.SubMConv3d(32,
# 32,
# 3,
# bias=False,
# indice_key="c0",
# algo=algo),
# # nn.BatchNorm1d(32),
# # nn.ReLU(),
# # spconv.SparseConv3d(64, 64, 2, 2, bias=False,
# # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo),
spconv.SubMConv3d(64, spconv.SubMConv3d(64,
64, 64,
3, 3,
...@@ -275,7 +318,7 @@ def main(): ...@@ -275,7 +318,7 @@ def main():
import pickle import pickle
np.random.seed(50051) np.random.seed(50051)
torch.manual_seed(50051) torch.manual_seed(50051)
# voxels, coors, spatial_shape = waymo_data() # voxels, coors, spatial_shape = waymo_data(num_features=128)
# with open("/home/yy/test_spconv.pkl", "wb") as f: # with open("/home/yy/test_spconv.pkl", "wb") as f:
# pickle.dump((voxels, coors, spatial_shape), f) # pickle.dump((voxels, coors, spatial_shape), f)
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f: with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
...@@ -312,7 +355,7 @@ def main(): ...@@ -312,7 +355,7 @@ def main():
# MaskImpGemm: 51.0ms # MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms # MaskSplitImpGemm: 41.1ms
# algo = None # algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train() net = Net(spatial_shape, algo).to(device).eval().to(dtype).train()
# net.load_state_dict(net.state_dict()) # net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net) spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape) print(coors_th.shape)
...@@ -345,18 +388,18 @@ def main(): ...@@ -345,18 +388,18 @@ def main():
print("spconv time", np.mean(times[10:])) print("spconv time", np.mean(times[10:]))
times = [] times = []
# for i in range(10): for i in range(10):
# out = net(voxels_th, coors_th, 1) out = net(voxels_th, coors_th, 1)
# print("------------") print("------------")
# torch.cuda.synchronize() torch.cuda.synchronize()
# t = time.time() t = time.time()
# out.features.backward(dout_t) out.features.backward(dout_t)
# torch.cuda.synchronize() torch.cuda.synchronize()
# times.append(time.time() - t) times.append(time.time() - t)
# # # print((net.grid == -1).float().sum(), net.grid.numel()) # # print((net.grid == -1).float().sum(), net.grid.numel())
# # # print("spconv time", time.time() - t) # # print("spconv time", time.time() - t)
# print("spconv bw time", np.mean(times[5:])) print("spconv bw time", np.mean(times[5:]))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -30,6 +30,7 @@ import numpy as np ...@@ -30,6 +30,7 @@ import numpy as np
import pccm import pccm
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from spconv.core_cc.csrc.sparse.convops import GemmTuneResult, ConvTuneResult
from spconv.test_utils import TestCase from spconv.test_utils import TestCase
from cumm import tensorview as tv from cumm import tensorview as tv
from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType
...@@ -38,11 +39,11 @@ from cumm.gemm.codeops import div_up ...@@ -38,11 +39,11 @@ from cumm.gemm.codeops import div_up
from spconv.core import AlgoHint, ConvAlgo from spconv.core import AlgoHint, ConvAlgo
from spconv.pytorch.conv import expand_nd from spconv.pytorch.conv import expand_nd
from spconv.pytorch import ops from spconv.pytorch import ops
from spconv.algo import CONV, GEMM, BestAlgoByProfile, BestConvAlgoByProfile from spconv.algo import GEMM, CONV, GEMM_CPP, CONV_CPP, BestAlgoByProfile, BestConvAlgoByProfile, GemmTunerSimple
from spconv.pytorch.cppcore import get_current_stream, torch_tensor_to_tv from spconv.pytorch.cppcore import get_current_stream, torch_tensor_to_tv
from spconv.test_utils import generate_sparse_data, params_grid from spconv.test_utils import generate_sparse_data, params_grid
import tqdm import tqdm
from spconv.constants import ALL_WEIGHT_IS_KRSC from spconv.constants import ALL_WEIGHT_IS_KRSC, SPCONV_CPP_GEMM
assert ALL_WEIGHT_IS_KRSC is True, "we only support KRSC in spconv >= 2.2" assert ALL_WEIGHT_IS_KRSC is True, "we only support KRSC in spconv >= 2.2"
...@@ -67,13 +68,13 @@ class SparseConvTester: ...@@ -67,13 +68,13 @@ class SparseConvTester:
self.dtype_th = NUMPY_DTYPE_TO_TORCH[dtype] self.dtype_th = NUMPY_DTYPE_TO_TORCH[dtype]
self.K = K self.K = K
self.C = C self.C = C
self.ksize = expand_nd(ksize, ndim) self.ksize = expand_nd(ndim, ksize)
self.stride = expand_nd(stride, ndim) self.stride = expand_nd(ndim, stride)
self.padding = expand_nd(padding, ndim) self.padding = expand_nd(ndim, padding, )
self.dilation = expand_nd(dilation, ndim) self.dilation = expand_nd(ndim, dilation)
self.N = N self.N = N
self.device = torch.device("cuda:0") self.device = torch.device("cuda:0")
op = expand_nd(0, ndim) op = expand_nd(ndim, 0)
self.kv: int = np.prod(self.ksize) self.kv: int = np.prod(self.ksize)
self.num_split = 1 if algo == ConvAlgo.MaskImplicitGemm else 2 self.num_split = 1 if algo == ConvAlgo.MaskImplicitGemm else 2
...@@ -139,7 +140,9 @@ class SparseConvTester: ...@@ -139,7 +140,9 @@ class SparseConvTester:
self.weight_ref = np.ascontiguousarray(self.weight_ref).reshape(-1, K, C) self.weight_ref = np.ascontiguousarray(self.weight_ref).reshape(-1, K, C)
self.out_ref, self.din_ref, self.dw_ref = self._get_ref_output() self.out_ref, self.din_ref, self.dw_ref = self._get_ref_output()
self.dw_ref = np.ascontiguousarray(self.dw_ref.transpose(1, 0, 2).reshape(K, *self.ksize, C)) self.dw_ref = np.ascontiguousarray(self.dw_ref.transpose(1, 0, 2).reshape(K, *self.ksize, C))
self.arch = tv.get_compute_capability()
def _get_ref_output(self): def _get_ref_output(self):
output_ref = np.zeros_like(self.output, dtype=np.float32) output_ref = np.zeros_like(self.output, dtype=np.float32)
...@@ -174,6 +177,8 @@ class SparseConvTester: ...@@ -174,6 +177,8 @@ class SparseConvTester:
dw_res = out_gather.astype( dw_res = out_gather.astype(
np.float32).T @ inp_gather.astype(np.float32) np.float32).T @ inp_gather.astype(np.float32)
dw_ref[filter_offset] = dw_res dw_ref[filter_offset] = dw_res
if self.dtype == np.int8:
output_ref = np.clip(output_ref, -127, 127)
return output_ref, dinput_ref, dw_ref return output_ref, dinput_ref, dw_ref
def get_operands(self, op_type: ConvOpType): def get_operands(self, op_type: ConvOpType):
...@@ -220,9 +225,15 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -220,9 +225,15 @@ def _test_impgemm_conv_cuda(subm: bool):
shapes = [[19, 18, 17]] shapes = [[19, 18, 17]]
batchsizes = [1] batchsizes = [1]
dtypes = [np.float32, np.float16] dtypes = [np.float32, np.float16]
dtypes = [np.int8]
test_case = TestCase() test_case = TestCase()
in_channels = [512] # in_channels = [32]
out_channels = [512] # out_channels = [32, 48, 64]
in_channels = [32, 47]
out_channels = [32, 48, 62]
in_channels = [32]
out_channels = [32]
multiple_base = 16 multiple_base = 16
if subm: if subm:
ksizes = [3] ksizes = [3]
...@@ -239,7 +250,7 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -239,7 +250,7 @@ def _test_impgemm_conv_cuda(subm: bool):
ConvAlgo.MaskImplicitGemm, ConvAlgo.MaskImplicitGemm,
] ]
arch = torch.cuda.get_device_capability() arch = torch.cuda.get_device_capability()
force_nvrtc = False
for shape, bs, C, K, k, s, p, d, algo, dtype in tqdm.tqdm(params_grid( for shape, bs, C, K, k, s, p, d, algo, dtype in tqdm.tqdm(params_grid(
shapes, batchsizes, in_channels, out_channels, ksizes, shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations, algos, dtypes)): strides, paddings, dilations, algos, dtypes)):
...@@ -259,15 +270,19 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -259,15 +270,19 @@ def _test_impgemm_conv_cuda(subm: bool):
spk = 1 spk = 1
for op_type in op_types: for op_type in op_types:
inp_tv, weight_tv, output_tv = tester.get_operands(op_type) inp_tv, weight_tv, output_tv = tester.get_operands(op_type)
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1) if SPCONV_CPP_GEMM:
print(avail_desps) avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch, op_type.value, -1, True, False)
else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1)
for desp in avail_desps: for desp in avail_desps:
if not subm: if not subm:
if op_type == ConvOpType.kForward: if op_type == ConvOpType.kForward:
output_tv.zero_() output_tv.zero_()
else: else:
inp_tv.zero_() inp_tv.zero_()
# this algo must succeed # this algo must succeed
mask_width = desp.tile_shape[0] mask_width = desp.tile_shape[0]
# if mask_width != 32: # if mask_width != 32:
...@@ -279,9 +294,9 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -279,9 +294,9 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_output_fwd = mask_width_to_mask_out_fwd[mask_width] mask_output_fwd = mask_width_to_mask_out_fwd[mask_width]
if subm: if subm:
if desp.op_type == ConvOpType.kForward.value: if desp.op_type.value == ConvOpType.kForward.value:
indice_pairs = tester.pair_fwd indice_pairs = tester.pair_fwd
elif desp.op_type == ConvOpType.kBackwardInput.value: elif desp.op_type.value == ConvOpType.kBackwardInput.value:
indice_pairs = tester.pair_bwd indice_pairs = tester.pair_bwd
else: else:
indice_pairs = tester.pair_fwd indice_pairs = tester.pair_fwd
...@@ -292,31 +307,57 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -292,31 +307,57 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_filter = tester.masks[j].item() mask_filter = tester.masks[j].item()
reverse_mask = False reverse_mask = False
if desp.op_type == ConvOpType.kBackwardWeight.value: if desp.op_type.value == ConvOpType.kBackwardWeight.value:
mask_op = mask_output[j] mask_op = mask_output[j]
else: else:
mask_op = tester.pair_mask_fwd_splits[j] mask_op = tester.pair_mask_fwd_splits[j]
if desp.op_type == ConvOpType.kBackwardInput.value: if desp.op_type.value == ConvOpType.kBackwardInput.value:
reverse_mask = True reverse_mask = True
mask_output_run = torch_tensor_to_tv(mask_output[j], dtype=tv.uint32) mask_output_run = torch_tensor_to_tv(mask_output[j], dtype=tv.uint32)
if desp.op_type == ConvOpType.kBackwardWeight.value: if desp.op_type.value == ConvOpType.kBackwardWeight.value:
mask_output_run = tv.Tensor() mask_output_run = tv.Tensor()
CONV.run_with_tuned_result( # force_nvrtc = desp.op_type.value == ConvOpType.kBackwardInput.value
BestConvAlgoByProfile(desp, spk), # if force_nvrtc:
desp.op_type, # desp.is_nvrtc = True
inp_tv, # print(force_nvrtc, desp.op_type, op_type)
weight_tv, if SPCONV_CPP_GEMM:
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32), CONV_CPP.run_with_tuned_result(
torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]), ConvTuneResult(desp, tester.arch, spk),
mask_output_run, desp.op_type.value,
torch_tensor_to_tv(indice_pairs), inp_tv,
reverse_mask, weight_tv,
mask_filter=mask_filter, output_tv,
mask_width=mask_width, torch_tensor_to_tv(mask_op, dtype=tv.uint32),
beta=beta, torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]),
verbose=False, mask_output_run,
) torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
force_nvrtc=force_nvrtc,
)
else:
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]),
mask_output_run,
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
force_nvrtc=force_nvrtc,
)
else: else:
if mask_width not in mask_width_to_mask_out_bwd: if mask_width not in mask_width_to_mask_out_bwd:
mask_width_to_mask_out_bwd[mask_width] = torch.zeros([2, div_up(tester.indices_np.shape[0], mask_width)], mask_width_to_mask_out_bwd[mask_width] = torch.zeros([2, div_up(tester.indices_np.shape[0], mask_width)],
...@@ -324,12 +365,12 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -324,12 +365,12 @@ def _test_impgemm_conv_cuda(subm: bool):
device=tester.device) device=tester.device)
mask_output_bwd = mask_width_to_mask_out_bwd[mask_width] mask_output_bwd = mask_width_to_mask_out_bwd[mask_width]
if desp.op_type == ConvOpType.kForward.value: if desp.op_type.value == ConvOpType.kForward.value:
indice_pairs = tester.pair_fwd # inp -> out indice_pairs = tester.pair_fwd # inp -> out
mask_ops = tester.pair_mask_fwd_splits mask_ops = tester.pair_mask_fwd_splits
mask_argsorts = tester.mask_argsort_fwd_splits mask_argsorts = tester.mask_argsort_fwd_splits
mask_output = mask_output_fwd mask_output = mask_output_fwd
elif desp.op_type == ConvOpType.kBackwardInput.value: elif desp.op_type.value == ConvOpType.kBackwardInput.value:
indice_pairs = tester.pair_bwd # out -> inp indice_pairs = tester.pair_bwd # out -> inp
mask_ops = tester.pair_mask_bwd_splits mask_ops = tester.pair_mask_bwd_splits
mask_argsorts = tester.mask_argsort_bwd_splits mask_argsorts = tester.mask_argsort_bwd_splits
...@@ -344,27 +385,46 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -344,27 +385,46 @@ def _test_impgemm_conv_cuda(subm: bool):
beta = 1 if j == 1 else 0 beta = 1 if j == 1 else 0
mask_filter = tester.masks[j].item() mask_filter = tester.masks[j].item()
reverse_mask = False reverse_mask = False
if desp.op_type == ConvOpType.kBackwardWeight.value: if desp.op_type.value == ConvOpType.kBackwardWeight.value:
mask_op = mask_output[j] mask_op = mask_output[j]
else: else:
mask_op = mask_ops[j] mask_op = mask_ops[j]
if SPCONV_CPP_GEMM:
CONV_CPP.run_with_tuned_result(
ConvTuneResult(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(mask_argsorts[j]),
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
else:
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(mask_argsorts[j]),
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, spk),
desp.op_type,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(mask_argsorts[j]),
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
out_ref = tester.out_ref out_ref = tester.out_ref
din_ref = tester.din_ref din_ref = tester.din_ref
dw_ref = tester.dw_ref dw_ref = tester.dw_ref
...@@ -374,8 +434,8 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -374,8 +434,8 @@ def _test_impgemm_conv_cuda(subm: bool):
test_case.assertAllClose(out_ref, out_my, atol=atol, rtol=rtol) test_case.assertAllClose(out_ref, out_my, atol=atol, rtol=rtol)
else: else:
error_norm = np.linalg.norm(out_ref.reshape(-1) - out_my.reshape(-1)) error_norm = np.linalg.norm(out_ref.reshape(-1) - out_my.reshape(-1))
# if (error_norm > 5): if (error_norm > 5):
print(f"{desp}, Error={error_norm}") print(f"{desp}, Error={error_norm}")
assert error_norm < 10 * multipler assert error_norm < 10 * multipler
# print(desp, ) # print(desp, )
else: else:
...@@ -389,7 +449,13 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -389,7 +449,13 @@ def _test_impgemm_conv_cuda(subm: bool):
for spk in [1, 4, 16, 64]: for spk in [1, 4, 16, 64]:
for mask_width, mask_output in mask_width_to_mask_out_fwd.items(): for mask_width, mask_output in mask_width_to_mask_out_fwd.items():
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width) if SPCONV_CPP_GEMM:
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch,
ConvOpType.kBackwardWeight.value, mask_width, True, False)
else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width)
for desp in avail_desps: for desp in avail_desps:
weight_tv.zero_() weight_tv.zero_()
if subm: if subm:
...@@ -403,8 +469,8 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -403,8 +469,8 @@ def _test_impgemm_conv_cuda(subm: bool):
# bit_ref = np.bitwise_or.reduce(mask_op_np, axis=0) # bit_ref = np.bitwise_or.reduce(mask_op_np, axis=0)
# bit_my = mask_filter # bit_my = mask_filter
CONV.run_with_tuned_result( CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, spk), BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type, desp.op_type.value,
inp_tv, inp_tv,
weight_tv, weight_tv,
output_tv, output_tv,
...@@ -430,8 +496,8 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -430,8 +496,8 @@ def _test_impgemm_conv_cuda(subm: bool):
mask_op = mask_output[j] mask_op = mask_output[j]
CONV.run_with_tuned_result( CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, spk), BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type, desp.op_type.value,
inp_tv, inp_tv,
weight_tv, weight_tv,
output_tv, output_tv,
...@@ -499,6 +565,8 @@ def _test_native_conv_cuda(subm: bool): ...@@ -499,6 +565,8 @@ def _test_native_conv_cuda(subm: bool):
pair_out = torch_tensor_to_tv(tester.pair_native)[1] pair_out = torch_tensor_to_tv(tester.pair_native)[1]
op_types = [ConvOpType.kForward, ConvOpType.kBackwardInput, ConvOpType.kBackwardWeight] op_types = [ConvOpType.kForward, ConvOpType.kBackwardInput, ConvOpType.kBackwardWeight]
# op_types = [ConvOpType.kForward]
indice_pair_num_cpu = tester.indice_num_per_loc_np indice_pair_num_cpu = tester.indice_num_per_loc_np
spk = 1 spk = 1
...@@ -517,9 +585,11 @@ def _test_native_conv_cuda(subm: bool): ...@@ -517,9 +585,11 @@ def _test_native_conv_cuda(subm: bool):
a = inp_tv a = inp_tv
c = output_tv c = output_tv
b = weight_tv.select(1, tester.kv // 2) b = weight_tv.select(1, tester.kv // 2)
if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, False, True, False, arch, ShuffleStrideType.ShuffleAC.value)
else:
avail_desps = GEMM.get_all_available(a, b, c, False, True, False, arch, ShuffleStrideType.ShuffleAC)
avail_desps = GEMM.get_all_available(a, b, c, False, True, False, arch, ShuffleStrideType.ShuffleAC)
for desp in avail_desps: for desp in avail_desps:
if subm: if subm:
torch.mm(inp_th, weight_th[:, tester.kv // 2].T, out=output_th) torch.mm(inp_th, weight_th[:, tester.kv // 2].T, out=output_th)
...@@ -538,22 +608,42 @@ def _test_native_conv_cuda(subm: bool): ...@@ -538,22 +608,42 @@ def _test_native_conv_cuda(subm: bool):
b = weight_tv.select(1, i) b = weight_tv.select(1, i)
# inp @ filter.T, NC @ KC # inp @ filter.T, NC @ KC
beta = 1.0 if inited else 0.0 beta = 1.0 if inited else 0.0
GEMM.run_with_tuned_result( if SPCONV_CPP_GEMM:
BestAlgoByProfile(desp, 1), GEMM_CPP.run_with_tuned_result(
a, GemmTuneResult(desp, tester.arch, 1),
b, a,
c, b,
False, c,
True, False,
False, True,
arch=arch, False,
stream=stream, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, stream_int=stream,
a_inds=inp_indices, shuffle_type=ShuffleStrideType.ShuffleAC.value,
c_inds=out_indices, a_inds=inp_indices,
hint=AlgoHint.Fowrard.value, b_inds=tv.Tensor(),
alpha=1.0, c_inds=out_indices,
beta=beta) hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
else:
GEMM.run_with_tuned_result(
BestAlgoByProfile(desp, tester.arch, 1),
a,
b,
c,
False,
True,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds=inp_indices,
c_inds=out_indices,
hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
inited = True inited = True
out_my = output_tv.cpu().numpy() out_my = output_tv.cpu().numpy()
if dtype != np.float16: if dtype != np.float16:
...@@ -570,7 +660,11 @@ def _test_native_conv_cuda(subm: bool): ...@@ -570,7 +660,11 @@ def _test_native_conv_cuda(subm: bool):
a = output_tv a = output_tv
b = weight_tv.select(1, tester.kv // 2) b = weight_tv.select(1, tester.kv // 2)
c = inp_tv c = inp_tv
avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC) if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC.value)
else:
avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC)
for desp in avail_desps: for desp in avail_desps:
if subm: if subm:
torch.mm(output_th, weight_th[:, tester.kv // 2], out=inp_th) torch.mm(output_th, weight_th[:, tester.kv // 2], out=inp_th)
...@@ -589,22 +683,42 @@ def _test_native_conv_cuda(subm: bool): ...@@ -589,22 +683,42 @@ def _test_native_conv_cuda(subm: bool):
b = weight_tv.select(1, i) b = weight_tv.select(1, i)
# out @ filter, NK @ KC # out @ filter, NK @ KC
beta = 1.0 if inited else 0.0 beta = 1.0 if inited else 0.0
GEMM.run_with_tuned_result( if SPCONV_CPP_GEMM:
BestAlgoByProfile(desp, 1), GEMM_CPP.run_with_tuned_result(
a, GemmTuneResult(desp, tester.arch, 1),
b, a,
c, b,
False, c,
False, False,
False, False,
arch=arch, False,
stream=stream, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, stream_int=stream,
a_inds=out_indices, shuffle_type=ShuffleStrideType.ShuffleAC.value,
c_inds=inp_indices, a_inds=out_indices,
hint=AlgoHint.Fowrard.value, b_inds=tv.Tensor(),
alpha=1.0, c_inds=inp_indices,
beta=beta) hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
else:
GEMM.run_with_tuned_result(
BestAlgoByProfile(desp, tester.arch, 1),
a,
b,
c,
False,
False,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds=out_indices,
c_inds=inp_indices,
hint=AlgoHint.Fowrard.value,
alpha=1.0,
beta=beta)
inited = True inited = True
din_my = inp_tv.cpu().numpy() din_my = inp_tv.cpu().numpy()
if dtype != np.float16: if dtype != np.float16:
...@@ -616,13 +730,18 @@ def _test_native_conv_cuda(subm: bool): ...@@ -616,13 +730,18 @@ def _test_native_conv_cuda(subm: bool):
else: else:
error_norm = np.linalg.norm(din_ref.reshape(-1) - din_my.reshape(-1)) error_norm = np.linalg.norm(din_ref.reshape(-1) - din_my.reshape(-1))
assert error_norm < 10 * multipler assert error_norm < 10 * multipler
else: else:
a = output_tv a = output_tv
b = inp_tv b = inp_tv
c = weight_tv.select(1, tester.kv // 2) c = weight_tv.select(1, tester.kv // 2)
avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB) if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB.value)
else:
avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB)
for desp in avail_desps: for desp in avail_desps:
# print(desp, C, K, k, s, p, d)
# desp.is_nvrtc = True
inited = subm inited = subm
weight_tv.zero_() weight_tv.zero_()
if subm: if subm:
...@@ -640,42 +759,57 @@ def _test_native_conv_cuda(subm: bool): ...@@ -640,42 +759,57 @@ def _test_native_conv_cuda(subm: bool):
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
a_inds = out_indices a_inds = out_indices
b_inds = inp_indices b_inds = inp_indices
if SPCONV_CPP_GEMM:
GEMM_CPP.run_with_tuned_result(
GemmTuneResult(desp, tester.arch, 32),
a,
b,
weight_tv.select(1, i),
True,
False,
False,
arch=arch,
stream_int=stream,
shuffle_type=ShuffleStrideType.ShuffleAB.value,
a_inds=a_inds,
b_inds=b_inds,
c_inds=tv.Tensor(),
hint=AlgoHint.BackwardWeight.value,
alpha=1.0,
beta=beta)
else:
GEMM.run_with_tuned_result(BestAlgoByProfile(desp, tester.arch, 32),
a,
b,
weight_tv.select(1, i),
True,
False,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAB,
a_inds=a_inds,
b_inds=b_inds,
hint=AlgoHint.BackwardWeight.value,
alpha=1.0,
beta=beta)
GEMM.run_with_tuned_result(BestAlgoByProfile(desp, 32),
a,
b,
weight_tv.select(1, i),
True,
False,
False,
arch=arch,
stream=stream,
shuffle_type=ShuffleStrideType.ShuffleAB,
a_inds=a_inds,
b_inds=b_inds,
hint=AlgoHint.BackwardWeight.value,
alpha=1.0,
beta=beta)
dw_my = weight_tv.cpu().numpy() dw_my = weight_tv.cpu().numpy()
if dtype != np.float16: if dtype != np.float16:
error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1)) error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1))
assert error_norm < 1 * multipler assert error_norm < 1 * multipler, f"{desp}, {error_norm}"
# test_case.assertAllClose(dw_ref, dw_my, atol=atol, rtol=rtol)
# print(desp, error_norm)
else: else:
error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1)) error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1))
# print(desp, error_norm) assert error_norm < 10 * multipler, f"{desp}, {error_norm}"
assert error_norm < 10 * multipler
def test_all_algo_unit(): def test_all_algo_unit():
# for i in range(5): # for i in range(5):
_test_impgemm_conv_cuda(True) _test_impgemm_conv_cuda(True)
# _test_impgemm_conv_cuda(False) _test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True) _test_native_conv_cuda(True)
# _test_native_conv_cuda(False) _test_native_conv_cuda(False)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -248,7 +248,7 @@ def test_spconv3d(): ...@@ -248,7 +248,7 @@ def test_spconv3d():
ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, ConvAlgo.Native, ConvAlgo.MaskImplicitGemm,
ConvAlgo.MaskSplitImplicitGemm ConvAlgo.MaskSplitImplicitGemm
] ]
# algos = [ConvAlgo.Native] algos = [ConvAlgo.Native]
for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid( for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes, devices, shapes, batchsizes, in_channels, out_channels, ksizes,
...@@ -308,7 +308,6 @@ def test_spconv3d(): ...@@ -308,7 +308,6 @@ def test_spconv3d():
filters_t = torch.from_numpy(filters).to(device).to(dtype) filters_t = torch.from_numpy(filters).to(device).to(dtype)
net_ref.net[0].weight.data[:] = filters_t.permute( net_ref.net[0].weight.data[:] = filters_t.permute(
0, 4, 1, 2, 3).contiguous() 0, 4, 1, 2, 3).contiguous()
net.net[0].weight.data[:] = filters_t net.net[0].weight.data[:] = filters_t
out_ref = net_ref(features_dense_t) out_ref = net_ref(features_dense_t)
out = net(features_t, indices_t, bs).dense() out = net(features_t, indices_t, bs).dense()
...@@ -529,4 +528,4 @@ def test_spmaxpool3d(): ...@@ -529,4 +528,4 @@ def test_spmaxpool3d():
if __name__ == "__main__": if __name__ == "__main__":
test_spmaxpool3d() test_spconv3d()
...@@ -222,6 +222,7 @@ class NetLight(nn.Module): ...@@ -222,6 +222,7 @@ class NetLight(nn.Module):
def _test_multi_impl(dtype: torch.dtype): def _test_multi_impl(dtype: torch.dtype):
# TODO: PyTorch 1.12 doesn't support half-precision (float16) mm on CPU, so the CPU backward pass is skipped for float16
# TODO remove or release this when tf32 op is ready # TODO remove or release this when tf32 op is ready
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False
...@@ -239,8 +240,6 @@ def _test_multi_impl(dtype: torch.dtype): ...@@ -239,8 +240,6 @@ def _test_multi_impl(dtype: torch.dtype):
np.float32) np.float32)
coors = np.ascontiguousarray( coors = np.ascontiguousarray(
sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
device = torch.device("cuda:0") device = torch.device("cuda:0")
device_cpu = torch.device("cpu:0") device_cpu = torch.device("cpu:0")
...@@ -275,17 +274,21 @@ def _test_multi_impl(dtype: torch.dtype): ...@@ -275,17 +274,21 @@ def _test_multi_impl(dtype: torch.dtype):
dout_t = torch.from_numpy(dout).to(device_cpu).to(dtype) dout_t = torch.from_numpy(dout).to(device_cpu).to(dtype)
dout_t_cu = torch.from_numpy(dout).to(device).to(dtype) dout_t_cu = torch.from_numpy(dout).to(device).to(dtype)
t = time.time()
print(1, time.time() - t)
out_cpu = net_native_cpu(voxels_th, coors_th, 1).dense() out_cpu = net_native_cpu(voxels_th, coors_th, 1).dense()
out_cpu.backward(dout_t) if dtype != torch.float16:
out_cpu.backward(dout_t)
out = net_native_gpu(voxels_th_cuda, coors_th_cuda, 1).dense() out = net_native_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
print(2, time.time() - t)
out.backward(dout_t_cu) out.backward(dout_t_cu)
out_imp = net_imp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense() out_imp = net_imp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
print(3, time.time() - t)
out_imp.backward(dout_t_cu) out_imp.backward(dout_t_cu)
out_simp = net_simp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense() out_simp = net_simp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
print(4, time.time() - t)
out_simp.backward(dout_t_cu) out_simp.backward(dout_t_cu)
with torch.no_grad(): with torch.no_grad():
...@@ -297,6 +300,7 @@ def _test_multi_impl(dtype: torch.dtype): ...@@ -297,6 +300,7 @@ def _test_multi_impl(dtype: torch.dtype):
error_native = torch.linalg.norm(dense_cpu - dense_native).cpu().item() error_native = torch.linalg.norm(dense_cpu - dense_native).cpu().item()
error_imp = torch.linalg.norm(dense_cpu - dense_imp).cpu().item() error_imp = torch.linalg.norm(dense_cpu - dense_imp).cpu().item()
error_simp = torch.linalg.norm(dense_cpu - dense_simp).cpu().item() error_simp = torch.linalg.norm(dense_cpu - dense_simp).cpu().item()
print(5, time.time() - t)
print("error_native", error_native) print("error_native", error_native)
print("error_imp", error_imp) print("error_imp", error_imp)
...@@ -320,15 +324,15 @@ def _test_multi_impl(dtype: torch.dtype): ...@@ -320,15 +324,15 @@ def _test_multi_impl(dtype: torch.dtype):
native_w = native_params[k] native_w = native_params[k]
imp_w = imp_params[k] imp_w = imp_params[k]
simp_w = simp_params[k] simp_w = simp_params[k]
cpu_w_grad = cpu_w.grad.detach().cuda()
native_w_grad = native_w.grad.detach() native_w_grad = native_w.grad.detach()
imp_w_grad = imp_w.grad.detach() imp_w_grad = imp_w.grad.detach()
simp_w_grad = simp_w.grad.detach() simp_w_grad = simp_w.grad.detach()
if dtype != torch.float16:
error_native = torch.linalg.norm(native_w_grad - cpu_w_grad).cpu().item() cpu_w_grad = cpu_w.grad.detach().cuda()
error_native = torch.linalg.norm(native_w_grad - cpu_w_grad).cpu().item()
error_imp = torch.linalg.norm(native_w_grad - imp_w_grad).cpu().item() error_imp = torch.linalg.norm(native_w_grad - imp_w_grad).cpu().item()
error_simp = torch.linalg.norm(native_w_grad - simp_w_grad).cpu().item() error_simp = torch.linalg.norm(native_w_grad - simp_w_grad).cpu().item()
print(k, error_native, error_imp, error_simp) print(k, error_imp, error_simp)
assert error_imp < 1 assert error_imp < 1
assert error_simp < 1 assert error_simp < 1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment