Commit 4791f582 authored by yan.yan's avatar yan.yan
Browse files

working on spconv 2.2

parent ab5e2e8e
...@@ -24,4 +24,4 @@ ...@@ -24,4 +24,4 @@
* VoxelGenerator has been replaced by ```spconv.pytorch.utils.PointToVoxel``` (torch API) or Point2VoxelGPU[1-4]d/Point2VoxelCPU[1-4]d (tv.Tensor API). * VoxelGenerator has been replaced by ```spconv.pytorch.utils.PointToVoxel``` (torch API) or Point2VoxelGPU[1-4]d/Point2VoxelCPU[1-4]d (tv.Tensor API).
* spconv < 2.1 don't support CPU. spconv 2.1+ support cpu for debug usage. * spconv < 2.1 don't support CPU. spconv 2.1+ support cpu for debug usage.
* test spconv 1.x model in spconv 2.x: Firstly set environment variable before run program, Then set all ```algo``` in conv/pool to ```ConvAlgo.Native```. Linux: ```export SPCONV_FILTER_HWIO="1"```, Windows powershell: ```$Env:SPCONV_FILTER_HWIO = "1"```. **WARNING** test spconv 1.x model don't support implicit gemm algorithm. * test spconv 1.x model in spconv 2.x: Linux: ```export SPCONV_SAVED_WEIGHT_LAYOUT="RSCK"```, Windows powershell: ```$Env:SPCONV_SAVED_WEIGHT_LAYOUT = "RSCK"```.
\ No newline at end of file
...@@ -53,7 +53,7 @@ def reduce_mask_count_x(mask: np.ndarray, width: int): ...@@ -53,7 +53,7 @@ def reduce_mask_count_x(mask: np.ndarray, width: int):
return maskr return maskr
def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): def dev_subm_inds_v2(subm: bool = True, run_conv: bool = True):
limit_input_n = 16384 limit_input_n = 16384
limit_input_n = None limit_input_n = None
np.random.seed(484) np.random.seed(484)
...@@ -64,13 +64,13 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): ...@@ -64,13 +64,13 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
voxels_np = voxels_np[:limit_input_n] voxels_np = voxels_np[:limit_input_n]
indices_np = indices_np[:limit_input_n] indices_np = indices_np[:limit_input_n]
spatial_shape = [19, 18, 17] # spatial_shape = [19, 18, 17]
sparse_dict = generate_sparse_data(spatial_shape, [1024], 128) # sparse_dict = generate_sparse_data(spatial_shape, [1024], 128)
voxels_np = np.ascontiguousarray(sparse_dict["features"]).astype( # voxels_np = np.ascontiguousarray(sparse_dict["features"]).astype(
np.float32) # np.float32)
indices_np = np.ascontiguousarray( # indices_np = np.ascontiguousarray(
sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) # sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
voxels = tv.from_numpy(voxels_np).cuda() voxels = tv.from_numpy(voxels_np).cuda()
indices = tv.from_numpy(indices_np).cuda() indices = tv.from_numpy(indices_np).cuda()
...@@ -96,7 +96,7 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): ...@@ -96,7 +96,7 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
dilation, out_padding, subm) dilation, out_padding, subm)
indice_num_per_loc_np = indice_num_per_loc.cpu().numpy() indice_num_per_loc_np = indice_num_per_loc.cpu().numpy()
indice_pairs_np = pair_ref.cpu().numpy() indice_pairs_np = pair_ref.cpu().numpy()
algo = ConvAlgo.MaskSplitImplicitGemm algo = ConvAlgo.MaskImplicitGemm
if algo == ConvAlgo.MaskImplicitGemm: if algo == ConvAlgo.MaskImplicitGemm:
num_split = 1 num_split = 1
else: else:
...@@ -116,7 +116,12 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): ...@@ -116,7 +116,12 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
pair_bwd = res[3] pair_bwd = res[3]
pair_mask_fwd_splits = res[4] pair_mask_fwd_splits = res[4]
pair_mask_bwd_splits = res[5] pair_mask_bwd_splits = res[5]
mask_tv = torch_tensor_to_tv(pair_mask_fwd_splits[0], dtype=tv.uint32).cpu().numpy()
bench_reduce_mask(mask_tv)
return
mask_argsort_fwd_splits = res[6] mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7] mask_argsort_bwd_splits = res[7]
masks = res[8] masks = res[8]
...@@ -358,6 +363,47 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): ...@@ -358,6 +363,47 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
np.linalg.norm( np.linalg.norm(
dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1))) dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1)))
def reverse_bits(a: np.ndarray):
a_unpack = np.unpackbits(a, bitorder="little")
return np.packbits(a_unpack)
def _count_mask_reduce(masks: np.ndarray):
masks_tv_count = SpconvOps.count_bits(tv.from_numpy(masks))
masks_tv_count_sum = masks_tv_count.numpy_view().sum()
reduce_count = reduce_mask_count(masks, 64)
print(masks_tv_count_sum, reduce_count, reduce_count / masks_tv_count_sum)
def bench_reduce_mask(masks: np.ndarray, width: int = 27):
# masks = np.random.randint(0, 2000000000, size=[100000], dtype=np.uint32)# & 0xffff
width_mask = np.array(0xffffffff, dtype=np.uint32) << (32 - width) >> (32 - width)
width_half_mask = np.array(0xffffffff, dtype=np.uint32) >> (32 - width // 2 - 1)
width_half_mask_left = width_half_mask << (width // 2 + 1)
print(bin(width_half_mask))
masks_sort = masks.copy()
masks_sort.sort()
_count_mask_reduce(masks_sort)
masks_sort = masks.copy() & width_half_mask
masks_sort.sort()
_count_mask_reduce(masks_sort)
# masks.sort()
# masks = masks & 0xffff
reversed_masks = SpconvOps.reverse_bits(tv.from_numpy(masks)).numpy()# & 0xffff0000
new_masks = np.concatenate([masks, reversed_masks])
np.random.shuffle(new_masks)
new_masks.sort()
_count_mask_reduce(new_masks)
new_masks &= width_half_mask
new_masks.sort()
_count_mask_reduce(new_masks)
if __name__ == "__main__": if __name__ == "__main__":
dev_subm_inds_v2() dev_subm_inds_v2()
...@@ -17,3 +17,5 @@ from . import build as _build ...@@ -17,3 +17,5 @@ from . import build as _build
from .core import ConvAlgo, AlgoHint from .core import ConvAlgo, AlgoHint
from . import constants from . import constants
from .__version__ import __version__ from .__version__ import __version__
SPCONV_VERSION_NUMBERS = list(map(int, __version__.split(".")))
\ No newline at end of file
...@@ -320,15 +320,16 @@ class SimpleGemm: ...@@ -320,15 +320,16 @@ class SimpleGemm:
c_inds.shape) c_inds.shape)
avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c, avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c,
arch, shuffle_type) arch, shuffle_type)
# c may be weight, may non-contiguous.
c_ = c.clone() # cumm.tensorview.Tensor don't support non-contiguous clone
c_ = c.clone_whole_storage()
times: List[float] = [] times: List[float] = []
best_gather_params = (-1, -1, -1, -1) best_gather_params = (-1, -1, -1, -1)
best_scatter_params = (-1, -1, -1, -1) best_scatter_params = (-1, -1, -1, -1)
all_profile_res: List[BestAlgoByProfile] = [] all_profile_res: List[BestAlgoByProfile] = []
for desp in avail: for desp in avail:
c_.zero_() c_.zero_whole_storage_()
split_k_slices = 1 split_k_slices = 1
# TODO better splitk selection # TODO better splitk selection
if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value: if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
......
...@@ -23,7 +23,17 @@ PACKAGE_ROOT = Path(__file__).parent.resolve() ...@@ -23,7 +23,17 @@ PACKAGE_ROOT = Path(__file__).parent.resolve()
EDITABLE_INSTALLED = project_is_installed( EDITABLE_INSTALLED = project_is_installed(
PACKAGE_NAME) and project_is_editable(PACKAGE_NAME) PACKAGE_NAME) and project_is_editable(PACKAGE_NAME)
_filter_hwio_env = os.getenv("SPCONV_FILTER_HWIO", "0") _filter_hwio_env = os.getenv("SPCONV_FILTER_HWIO", None)
FILTER_HWIO = _filter_hwio_env == "1" if _filter_hwio_env is not None:
raise NotImplementedError("SPCONV_FILTER_HWIO is deprecated. use SPCONV_SAVED_WEIGHT_LAYOUT instead.")
DISABLE_JIT = os.getenv("SPCONV_DISABLE_JIT", "0") == "1" DISABLE_JIT = os.getenv("SPCONV_DISABLE_JIT", "0") == "1"
NDIM_DONT_CARE = 3 NDIM_DONT_CARE = 3
FILTER_HWIO = False
SAVED_WEIGHT_LAYOUT = os.getenv("SPCONV_SAVED_WEIGHT_LAYOUT", "")
if SAVED_WEIGHT_LAYOUT != "":
assert SAVED_WEIGHT_LAYOUT in ["KRSC", "RSKC", "RSCK"], "please set SAVED_WEIGHT_LAYOUT to KRSC, RSKC or RSCK"
ALL_WEIGHT_IS_KRSC = True
\ No newline at end of file
...@@ -83,6 +83,9 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [ ...@@ -83,6 +83,9 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f32,f32,f32,f32,f32"], *gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((16, 32, 8), (16, 16, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# fall back kernels if mat is misaligned for half # fall back kernels if mat is misaligned for half
# TODO use access-per-vector kernel instead of simt kernel for fallback # TODO use access-per-vector kernel instead of simt kernel for fallback
*gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f16,f16"], *gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f16,f16"],
......
...@@ -283,6 +283,13 @@ class SpconvOps: ...@@ -283,6 +283,13 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def reverse_bits(a: Tensor) -> Tensor:
"""
Args:
a:
"""
...
@staticmethod
def calc_point2voxel_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: def calc_point2voxel_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]:
""" """
Args: Args:
......
...@@ -815,6 +815,66 @@ class SpconvOps(pccm.Class): ...@@ -815,6 +815,66 @@ class SpconvOps(pccm.Class):
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.cuda.static_function
def reverse_bits(self):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.add_dependency(TensorViewKernel)
code.arg("a", "tv::Tensor")
code.code_after_include = f"""
__global__ void reverse_bits_kernel_64(const uint64_t* data, uint64_t* out, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
out[i] = __brevll(reinterpret_cast<const unsigned long long*>(data)[i]);
}}
}}
__global__ void reverse_bits_kernel(const uint32_t* data, uint32_t* out, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
out[i] = __brev(data[i]);
}}
}}
uint32_t reverse(uint32_t x)
{{
x = ((x >> 1) & 0x55555555u) | ((x & 0x55555555u) << 1);
x = ((x >> 2) & 0x33333333u) | ((x & 0x33333333u) << 2);
x = ((x >> 4) & 0x0f0f0f0fu) | ((x & 0x0f0f0f0fu) << 4);
x = ((x >> 8) & 0x00ff00ffu) | ((x & 0x00ff00ffu) << 8);
x = ((x >> 16) & 0xffffu) | ((x & 0xffffu) << 16);
return x;
}}
int reverse(uint64_t i)
{{
return (reverse(uint32_t(i)) << 32) | reverse(uint32_t(i >> 32));
}}
"""
code.raw(f"""
tv::Tensor res(a.shape(), a.dtype(), a.device());
tv::dispatch<uint32_t, uint64_t>(a.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto res_ptr = res.data_ptr<T>();
auto a_ptr = a.data_ptr<const T>();
if (a.device() == -1){{
for (int i = 0; i < a.size(); ++i){{
res_ptr[i] = reverse(a_ptr[i]);
}}
}}else{{
tv::cuda::Launch launcher(a.size());
tv::if_constexpr<std::is_same<T, uint64_t>::value>([=](auto _)mutable{{
launcher(_(reverse_bits_kernel_64), a_ptr, res_ptr, int(a.size()));
}}, [=](auto _)mutable{{
launcher(_(reverse_bits_kernel), a_ptr, res_ptr, int(a.size()));
}});
}}
}});
return res;
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
def calc_point2voxel_meta_data(self): def calc_point2voxel_meta_data(self):
......
...@@ -16,7 +16,6 @@ import pccm ...@@ -16,7 +16,6 @@ import pccm
from ccimport import compat from ccimport import compat
from cumm.common import TensorView from cumm.common import TensorView
class OMPLib(pccm.Class): class OMPLib(pccm.Class):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import pccm import pccm
from cumm.common import TensorView from cumm.common import TensorView, GemmDTypes
from cumm.constants import CUMM_CPU_ONLY_BUILD from cumm.constants import CUMM_CPU_ONLY_BUILD
from spconv.csrc.sparse.cpu_core import OMPLib from spconv.csrc.sparse.cpu_core import OMPLib
from typing import List from typing import List
...@@ -24,7 +24,7 @@ class GatherCPU(pccm.Class): ...@@ -24,7 +24,7 @@ class GatherCPU(pccm.Class):
super().__init__() super().__init__()
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
self.add_dependency(OMPLib) self.add_dependency(OMPLib)
self.add_dependency(TensorView) self.add_dependency(TensorView, GemmDTypes)
self.add_include("tensorview/parallel/all.h") self.add_include("tensorview/parallel/all.h")
@pccm.static_function @pccm.static_function
...@@ -39,7 +39,7 @@ class GatherCPU(pccm.Class): ...@@ -39,7 +39,7 @@ class GatherCPU(pccm.Class):
auto nhot = inds.dim(0); auto nhot = inds.dim(0);
int channel = in.dim(1); int channel = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{ tv::dispatch<float, double, tv::bfloat16_t, tv::half_t>(out.dtype(), [&](auto I){{
auto indices_data = inds.data_ptr<const int>(); auto indices_data = inds.data_ptr<const int>();
using T = TV_DECLTYPE(I); using T = TV_DECLTYPE(I);
T *buffer_data = out.data_ptr<T>(); T *buffer_data = out.data_ptr<T>();
...@@ -65,7 +65,7 @@ class GatherCPU(pccm.Class): ...@@ -65,7 +65,7 @@ class GatherCPU(pccm.Class):
// tv::check_shape(inds, {{in.dim(0)}}); // tv::check_shape(inds, {{in.dim(0)}});
auto nhot = inds.dim(0); auto nhot = inds.dim(0);
int channel = in.dim(1); int channel = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{ tv::dispatch<float, double, tv::bfloat16_t, tv::half_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I); using T = TV_DECLTYPE(I);
auto indices_data = inds.data_ptr<const int>(); auto indices_data = inds.data_ptr<const int>();
const T *buffer_data = in.data_ptr<const T>(); const T *buffer_data = in.data_ptr<const T>();
......
...@@ -18,7 +18,7 @@ from cumm.gemm.core.metaarray import MetaArray, seq ...@@ -18,7 +18,7 @@ from cumm.gemm.core.metaarray import MetaArray, seq
from cumm import dtypes from cumm import dtypes
import pccm import pccm
from cumm.gemm.layout import TensorGeneric, to_stride from cumm.gemm.layout import TensorGeneric, to_stride
from cumm.common import TensorView, TensorViewHashKernel, TensorViewKernel, ThrustLib, GemmBasic from cumm.common import TensorView, GemmDTypes, TensorViewKernel, ThrustLib, GemmBasic
from cumm.gemm import codeops from cumm.gemm import codeops
from typing import List from typing import List
from cumm.conv.params import ConvProblem from cumm.conv.params import ConvProblem
...@@ -353,7 +353,7 @@ class IndiceMaxPool(pccm.Class): ...@@ -353,7 +353,7 @@ class IndiceMaxPool(pccm.Class):
class IndiceMaxPoolCPU(pccm.Class): class IndiceMaxPoolCPU(pccm.Class):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_dependency(TensorView) self.add_dependency(TensorView, GemmDTypes)
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
self.add_dependency(OMPLib) self.add_dependency(OMPLib)
self.add_include("tensorview/parallel/all.h") self.add_include("tensorview/parallel/all.h")
...@@ -370,7 +370,7 @@ class IndiceMaxPoolCPU(pccm.Class): ...@@ -370,7 +370,7 @@ class IndiceMaxPoolCPU(pccm.Class):
code.raw(f""" code.raw(f"""
int nhot = out_inds.dim(0); int nhot = out_inds.dim(0);
int num_features = in.dim(1); int num_features = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{ tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I); using T = TV_DECLTYPE(I);
auto out_features = out.data_ptr<T>(); auto out_features = out.data_ptr<T>();
auto in_features = in.data_ptr<const T>(); auto in_features = in.data_ptr<const T>();
...@@ -410,7 +410,7 @@ class IndiceMaxPoolCPU(pccm.Class): ...@@ -410,7 +410,7 @@ class IndiceMaxPoolCPU(pccm.Class):
code.raw(f""" code.raw(f"""
int nhot = out_inds.dim(0); int nhot = out_inds.dim(0);
int num_features = in.dim(1); int num_features = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{ tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I); using T = TV_DECLTYPE(I);
auto out_features = out.data_ptr<const T>(); auto out_features = out.data_ptr<const T>();
auto in_features = in.data_ptr<const T>(); auto in_features = in.data_ptr<const T>();
......
...@@ -23,15 +23,17 @@ from torch.nn import init ...@@ -23,15 +23,17 @@ from torch.nn import init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from spconv import pytorch as spconv from spconv import pytorch as spconv
from spconv import SPCONV_VERSION_NUMBERS
from spconv.core import ConvAlgo from spconv.core import ConvAlgo
from spconv.pytorch import functional as Fsp from spconv.pytorch import functional as Fsp
from spconv.pytorch import ops from spconv.pytorch import ops
from spconv.cppconstants import CPU_ONLY_BUILD from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
from spconv.constants import FILTER_HWIO from spconv.constants import SAVED_WEIGHT_LAYOUT, ALL_WEIGHT_IS_KRSC
from spconv.utils import nullcontext from spconv.utils import nullcontext
FILTER_HWIO = False
def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo): def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo):
dimensions = tensor.ndimension() dimensions = tensor.ndimension()
...@@ -132,7 +134,7 @@ class SparseConvolution(SparseModule): ...@@ -132,7 +134,7 @@ class SparseConvolution(SparseModule):
assert algo == ConvAlgo.Native, "cpu only build only support native algorithm" assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
self.algo = algo self.algo = algo
# self.algo = ConvAlgo.Native # self.algo = ConvAlgo.Native
if self.algo == ConvAlgo.Native: if self.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO: if FILTER_HWIO:
# RSCK # RSCK
self.weight = Parameter( self.weight = Parameter(
...@@ -152,6 +154,37 @@ class SparseConvolution(SparseModule): ...@@ -152,6 +154,37 @@ class SparseConvolution(SparseModule):
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()
self._register_load_state_dict_pre_hook(self._load_weight_different_layout)
def _load_weight_different_layout(
self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
if not SAVED_WEIGHT_LAYOUT:
return
key = prefix + "weight"
assert key in state_dict
ndim = self.ndim
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(ndim, *range(ndim), ndim + 1).contiguous()
elif SAVED_WEIGHT_LAYOUT == "RSCK":
state_dict[key] = state_dict[key].permute(ndim + 1, *range(ndim), ndim).contiguous()
if ALL_WEIGHT_IS_KRSC or self.algo != ConvAlgo.Native:
# in spconv 2.2, we only support KRSC layout.
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(ndim, *range(ndim), ndim + 1).contiguous()
elif SAVED_WEIGHT_LAYOUT == "RSCK":
state_dict[key] = state_dict[key].permute(ndim + 1, *range(ndim), ndim).contiguous()
else:
if self.algo == ConvAlgo.Native:
# to RSCK
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(*range(ndim), ndim + 1, ndim).contiguous()
elif SAVED_WEIGHT_LAYOUT == "KRSC":
state_dict[key] = state_dict[key].permute(*range(1, ndim + 1), 0, ndim + 1).contiguous()
def extra_repr(self): def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}') ', stride={stride}')
......
...@@ -31,7 +31,7 @@ _TORCH_DTYPE_TO_TV = { ...@@ -31,7 +31,7 @@ _TORCH_DTYPE_TO_TV = {
def torch_tensor_to_tv(ten: torch.Tensor, def torch_tensor_to_tv(ten: torch.Tensor,
dtype: Optional[int] = None, dtype: Optional[int] = None,
shape: Optional[List[int]] = None): shape: Optional[List[int]] = None):
assert ten.is_contiguous(), "must be contiguous tensor" # assert ten.is_contiguous(), "must be contiguous tensor"
ptr = ten.data_ptr() ptr = ten.data_ptr()
device = ten.device device = ten.device
if device.type == "cpu": if device.type == "cpu":
...@@ -44,7 +44,7 @@ def torch_tensor_to_tv(ten: torch.Tensor, ...@@ -44,7 +44,7 @@ def torch_tensor_to_tv(ten: torch.Tensor,
shape = list(ten.shape) shape = list(ten.shape)
if dtype is None: if dtype is None:
dtype = _TORCH_DTYPE_TO_TV[ten.dtype] dtype = _TORCH_DTYPE_TO_TV[ten.dtype]
return tv.from_blob(ptr, shape, dtype, tv_device) return tv.from_blob(ptr, shape, list(ten.stride()), dtype, tv_device)
def get_current_stream(): def get_current_stream():
......
...@@ -36,7 +36,7 @@ else: ...@@ -36,7 +36,7 @@ else:
GEMM = None GEMM = None
CONV = None CONV = None
import time import time
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO, ALL_WEIGHT_IS_KRSC
from cumm.gemm import codeops from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
...@@ -606,21 +606,40 @@ def indice_conv(features: torch.Tensor, ...@@ -606,21 +606,40 @@ def indice_conv(features: torch.Tensor,
if features.dtype == torch.int8 or features.dtype == torch.qint8: if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
if FILTER_HWIO:
out_channel = filters.shape[-1] if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
if FILTER_HWIO:
out_channel = filters.shape[-1]
filter_shape_per_kv = [filters.shape[-2], out_channel]
else:
out_channel = filters.shape[-2]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0]
else: else:
out_channel = filters.shape[-2] kv_dim = 1
filters = filters.reshape(-1, *filters.shape[-2:]) out_channel = filters.shape[0]
kv = filters.shape[0] filters = filters.reshape(out_channel, -1, filters.shape[-1])
is_KC_not_CK = True
kv = filters.shape[1]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2 kv_center = kv // 2
if subm: if subm:
# out_features = torch.zeros((num_activate_out, out_channel), # out_features = torch.zeros((num_activate_out, out_channel),
# dtype=features.dtype, # dtype=features.dtype,
# device=features.device) # device=features.device)
if FILTER_HWIO: if not ALL_WEIGHT_IS_KRSC:
out_features = torch.mm(features, filters[kv_center]) if not is_KC_not_CK:
out_features = torch.mm(features, filters[kv_center])
else:
out_features = torch.mm(features, filters[kv_center].T)
else: else:
out_features = torch.mm(features, filters[kv_center].T) out_features = torch.mm(features, filters[:, kv_center].T)
else: else:
out_features = torch.zeros((num_activate_out, out_channel), out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype, dtype=features.dtype,
...@@ -640,7 +659,6 @@ def indice_conv(features: torch.Tensor, ...@@ -640,7 +659,6 @@ def indice_conv(features: torch.Tensor,
pair_in = indice_pairs_tv[int(inverse)] pair_in = indice_pairs_tv[int(inverse)]
pair_out = indice_pairs_tv[int(not inverse)] pair_out = indice_pairs_tv[int(not inverse)]
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
if not features.is_cuda: if not features.is_cuda:
# perform gather-mm-scatter_add for cpu data # perform gather-mm-scatter_add for cpu data
assert not filters.is_cuda assert not filters.is_cuda
...@@ -662,7 +680,8 @@ def indice_conv(features: torch.Tensor, ...@@ -662,7 +680,8 @@ def indice_conv(features: torch.Tensor,
inp_indices = pair_in[i].slice_first_axis(0, nhot) inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices) SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices)
filters_cur = filters[i] if FILTER_HWIO else filters[i].T filters_i = filters.select(kv_dim, i)
filters_cur = filters_i if not is_KC_not_CK else filters_i.T
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot]) torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices) SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices)
...@@ -689,10 +708,10 @@ def indice_conv(features: torch.Tensor, ...@@ -689,10 +708,10 @@ def indice_conv(features: torch.Tensor,
filters_tv.dtype, filters_tv.dtype,
c.dtype, c.dtype,
a.shape, a.shape,
filters.shape[-2:], filter_shape_per_kv,
c.shape, c.shape,
False, False,
False if FILTER_HWIO else True, is_KC_not_CK,
False, False,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
...@@ -708,13 +727,14 @@ def indice_conv(features: torch.Tensor, ...@@ -708,13 +727,14 @@ def indice_conv(features: torch.Tensor,
inp_indices = torch_tensor_to_tv(inp_indices_th) inp_indices = torch_tensor_to_tv(inp_indices_th)
out_indices = torch_tensor_to_tv(out_indices_th) out_indices = torch_tensor_to_tv(out_indices_th)
filter_tv = torch_tensor_to_tv(filters)[profile_idx] filter_tv = torch_tensor_to_tv(filters)[profile_idx]
filter_tv = torch_tensor_to_tv(filters).select(kv_dim, profile_idx)
tuned_res, min_time = GEMM.tune_and_cache( tuned_res, min_time = GEMM.tune_and_cache(
a, a,
filter_tv, filter_tv,
c, c,
False, False,
False if FILTER_HWIO else True, is_KC_not_CK,
False, False,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
...@@ -736,7 +756,7 @@ def indice_conv(features: torch.Tensor, ...@@ -736,7 +756,7 @@ def indice_conv(features: torch.Tensor,
continue continue
inp_indices = pair_in[i].slice_first_axis(0, nhot) inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
b = filters_tv[i] b = filters_tv.select(kv_dim, i)
# inp @ filter.T, NC @ KC # inp @ filter.T, NC @ KC
beta = 1.0 if inited else 0.0 beta = 1.0 if inited else 0.0
algo_desp = GEMM.run_with_tuned_result( algo_desp = GEMM.run_with_tuned_result(
...@@ -745,7 +765,7 @@ def indice_conv(features: torch.Tensor, ...@@ -745,7 +765,7 @@ def indice_conv(features: torch.Tensor,
b, b,
c, c,
False, False,
False if FILTER_HWIO else True, is_KC_not_CK,
False, False,
arch=arch, arch=arch,
stream=stream, stream=stream,
...@@ -783,11 +803,27 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -783,11 +803,27 @@ def indice_conv_backward(features: torch.Tensor,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# print(out_bp.mean(), out_bp.max(), out_bp.min()) # print(out_bp.mean(), out_bp.max(), out_bp.min())
num_activate_out = out_bp.shape[0]
out_channel = out_bp.shape[-1]
filters_shape = filters.shape filters_shape = filters.shape
filters = filters.reshape(-1, *filters.shape[-2:]) if not ALL_WEIGHT_IS_KRSC:
kv = filters.shape[0] kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
if FILTER_HWIO:
out_channel = filters.shape[-1]
filter_shape_per_kv = [filters.shape[-2], out_channel]
else:
out_channel = filters.shape[-2]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0]
else:
kv_dim = 1
out_channel = filters.shape[0]
filters = filters.reshape(out_channel, -1, filters.shape[-1])
is_KC_not_CK = True
kv = filters.shape[1]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2 kv_center = kv // 2
if not out_bp.is_contiguous(): if not out_bp.is_contiguous():
out_bp = out_bp.contiguous() out_bp = out_bp.contiguous()
...@@ -797,20 +833,24 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -797,20 +833,24 @@ def indice_conv_backward(features: torch.Tensor,
if subm: if subm:
dfilters = torch.zeros_like(filters) dfilters = torch.zeros_like(filters)
if FILTER_HWIO: if not ALL_WEIGHT_IS_KRSC:
torch.mm(features.T, out_bp, out=dfilters[kv_center]) if not is_KC_not_CK:
# TODO can we use torch mm for f16 backward weight? torch.mm(features.T, out_bp, out=dfilters[kv_center])
din = torch.mm(out_bp, filters[kv_center].T) din = torch.mm(out_bp, filters[kv_center].T)
else:
torch.mm(out_bp.T, features, out=dfilters[kv_center])
din = torch.mm(out_bp, filters[kv_center])
else: else:
torch.mm(out_bp.T, features, out=dfilters[kv_center]) # KN @ NC
# TODO can we use torch mm for f16 backward weight? torch.mm(out_bp.T, features, out=dfilters[:, kv_center])
din = torch.mm(out_bp, filters[kv_center]) # NK @ KC
din = torch.mm(out_bp, filters[:, kv_center])
else: else:
dfilters = torch.zeros_like(filters) dfilters = torch.zeros_like(filters)
din = torch.zeros_like(features) din = torch.zeros_like(features)
if kv == 1 and subm: if kv == 1 and subm:
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
inited: bool = subm inited: bool = subm
indice_pairs_tv = torch_tensor_to_tv(indice_pairs) indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
# torch slice (a_th[x]) is very slow, so we need to use tv.Tensor earlier. # torch slice (a_th[x]) is very slow, so we need to use tv.Tensor earlier.
...@@ -854,12 +894,18 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -854,12 +894,18 @@ def indice_conv_backward(features: torch.Tensor,
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
SpconvOps.gather_cpu(inp_buffer_tv, features_tv, inp_indices) SpconvOps.gather_cpu(inp_buffer_tv, features_tv, inp_indices)
SpconvOps.gather_cpu(out_buffer_tv, out_bp_tv, out_indices) SpconvOps.gather_cpu(out_buffer_tv, out_bp_tv, out_indices)
filters_T_cur = filters[i].T if FILTER_HWIO else filters[i] filters_i = filters.select(kv_dim, i)
dfilters_cur = dfilters[i] if FILTER_HWIO else dfilters[i].T dfilters_i = dfilters.select(kv_dim, i)
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_cur)
torch.mm(out_buffer[:nhot], filters_T_cur, out=inp_buffer[:nhot])
filters_KC = filters_i if is_KC_not_CK else filters_i.T
if is_KC_not_CK:
# KN @ NC
torch.mm(out_buffer[:nhot].T, inp_buffer[:nhot], out=dfilters_i)
else:
# CN @ NK
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_i)
# NK @ KC
torch.mm(out_buffer[:nhot], filters_KC, out=inp_buffer[:nhot])
SpconvOps.scatter_add_cpu(din_tv, inp_buffer_tv, inp_indices) SpconvOps.scatter_add_cpu(din_tv, inp_buffer_tv, inp_indices)
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
...@@ -883,10 +929,10 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -883,10 +929,10 @@ def indice_conv_backward(features: torch.Tensor,
filters_tv.dtype, filters_tv.dtype,
din_tv.dtype, din_tv.dtype,
out_bp_tv.shape, out_bp_tv.shape,
filters.shape[-2:], filter_shape_per_kv,
din_tv.shape, din_tv.shape,
False, False,
True if FILTER_HWIO else False, not is_KC_not_CK,
False, False,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
...@@ -896,13 +942,13 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -896,13 +942,13 @@ def indice_conv_backward(features: torch.Tensor,
if tuned_res_dgrad is None: if tuned_res_dgrad is None:
inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile) inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile)
out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile) out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile)
filter_tv = filters_tv[profile_idx] filter_tv = filters_tv.select(kv_dim, profile_idx)
tuned_res_dgrad, min_time = GEMM.tune_and_cache( tuned_res_dgrad, min_time = GEMM.tune_and_cache(
out_bp_tv, out_bp_tv,
filter_tv, filter_tv,
din_tv, din_tv,
False, False,
True if FILTER_HWIO else False, not is_KC_not_CK,
False, False,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
...@@ -912,7 +958,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -912,7 +958,7 @@ def indice_conv_backward(features: torch.Tensor,
beta=0.0, beta=0.0,
hint=AlgoHint.BackwardInput.value, hint=AlgoHint.BackwardInput.value,
stream=stream) stream=stream)
if not FILTER_HWIO: if is_KC_not_CK:
a_wgrad = out_bp_tv a_wgrad = out_bp_tv
b_wgrad = features_tv b_wgrad = features_tv
else: else:
...@@ -924,7 +970,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -924,7 +970,7 @@ def indice_conv_backward(features: torch.Tensor,
filters_tv.dtype, filters_tv.dtype,
a_wgrad.shape, a_wgrad.shape,
b_wgrad.shape, b_wgrad.shape,
filters.shape[-2:], filter_shape_per_kv,
True, True,
False, False,
False, False,
...@@ -937,8 +983,8 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -937,8 +983,8 @@ def indice_conv_backward(features: torch.Tensor,
if tuned_res_wgrad is None: if tuned_res_wgrad is None:
inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile) inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile)
out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile) out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile)
dfilter_tv = dfilters_tv[profile_idx] dfilter_tv = dfilters_tv.select(kv_dim, profile_idx)
if not FILTER_HWIO: if is_KC_not_CK:
a_inds_wgrad = out_indices a_inds_wgrad = out_indices
b_inds_wgrad = inp_indices b_inds_wgrad = inp_indices
else: else:
...@@ -961,7 +1007,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -961,7 +1007,7 @@ def indice_conv_backward(features: torch.Tensor,
stream=stream) stream=stream)
# print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time) # print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time)
# get workspace size for wgrad # get workspace size for wgrad
if not FILTER_HWIO: if is_KC_not_CK:
a_shape = [maxnhot, out_bp_tv.dim(1)] a_shape = [maxnhot, out_bp_tv.dim(1)]
b_shape = [maxnhot, features_tv.dim(1)] b_shape = [maxnhot, features_tv.dim(1)]
else: else:
...@@ -1003,13 +1049,13 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1003,13 +1049,13 @@ def indice_conv_backward(features: torch.Tensor,
inp_indices = pair_in[i].slice_first_axis(0, nhot) inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
# out.T @ inp, NK @ NC # out.T @ inp, NK @ NC
# print(features_tv.shape, out_bp_tv.shape) filter_i_tv = filters_tv.select(kv_dim, i)
GEMM.run_with_tuned_result(tuned_res_dgrad, GEMM.run_with_tuned_result(tuned_res_dgrad,
out_bp_tv, out_bp_tv,
filters_tv[i], filter_i_tv,
din_tv, din_tv,
False, False,
True if FILTER_HWIO else False, not is_KC_not_CK,
False, False,
arch=arch, arch=arch,
stream=stream, stream=stream,
...@@ -1033,7 +1079,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1033,7 +1079,7 @@ def indice_conv_backward(features: torch.Tensor,
GEMM.run_with_tuned_result(tuned_res_wgrad, GEMM.run_with_tuned_result(tuned_res_wgrad,
a, a,
b, b,
dfilters_tv[i], dfilters_tv.select(kv_dim, i),
True, True,
False, False,
False, False,
......
...@@ -168,8 +168,8 @@ class Net(nn.Module): ...@@ -168,8 +168,8 @@ class Net(nn.Module):
# nn.ReLU(), # nn.ReLU(),
# spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo), # spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo),
# # nn.BatchNorm1d(128), # # # nn.BatchNorm1d(128),
# # nn.ReLU(), # # # nn.ReLU(),
# spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo), # spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo),
) )
...@@ -312,7 +312,8 @@ def main(): ...@@ -312,7 +312,8 @@ def main():
# MaskImpGemm: 51.0ms # MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms # MaskSplitImpGemm: 41.1ms
# algo = None # algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype).train() net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train()
# net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net) spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape) print(coors_th.shape)
out = net(voxels_th, coors_th, 1) out = net(voxels_th, coors_th, 1)
...@@ -323,25 +324,25 @@ def main(): ...@@ -323,25 +324,25 @@ def main():
print(out.spatial_shape, out.features.mean(), out.features.max(), print(out.spatial_shape, out.features.mean(), out.features.max(),
out.features.min()) out.features.min())
times = [] # times = []
with torch.no_grad(): # with torch.no_grad():
for i in range(20): # for i in range(20):
print("------------") # print("------------")
torch.cuda.synchronize() # torch.cuda.synchronize()
t = time.time() # t = time.time()
out_nograd = net(voxels_th, coors_th, 1, True) # out_nograd = net(voxels_th, coors_th, 1, False)
timer = out_nograd._timer # timer = out_nograd._timer
res = timer.collect_by_name("forward", timer.get_all_pair_time()) # # res = timer.collect_by_name("forward", timer.get_all_pair_time())
res2 = timer.collect_by_name("forward0", timer.get_all_pair_time()) # # res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
print(sum(res.values()) + sum(res2.values())) # # print(sum(res.values()) + sum(res2.values()))
# print(timer.get_all_pair_time()) # # print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values())) # # print(sum(timer.get_all_pair_time().values()))
torch.cuda.synchronize() # torch.cuda.synchronize()
# sort_bench() # # sort_bench()
times.append(time.time() - t) # times.append(time.time() - t)
print("spconv time", np.mean(times[10:])) # print("spconv time", np.mean(times[10:]))
# times = [] # times = []
# for i in range(10): # for i in range(10):
......
...@@ -23,7 +23,7 @@ from spconv.core import ConvAlgo ...@@ -23,7 +23,7 @@ from spconv.core import ConvAlgo
import spconv.pytorch as spconv import spconv.pytorch as spconv
from spconv.test_utils import TestCase, generate_sparse_data, params_grid from spconv.test_utils import TestCase, generate_sparse_data, params_grid
from spconv.constants import FILTER_HWIO from spconv.constants import ALL_WEIGHT_IS_KRSC, FILTER_HWIO
# import sparseconvnet as scn # import sparseconvnet as scn
# we must disable tf32 to increase reference precision. # we must disable tf32 to increase reference precision.
...@@ -368,14 +368,14 @@ class TestSpConv(TestCase): ...@@ -368,14 +368,14 @@ class TestSpConv(TestCase):
ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, ConvAlgo.Native, ConvAlgo.MaskImplicitGemm,
ConvAlgo.MaskSplitImplicitGemm ConvAlgo.MaskSplitImplicitGemm
] ]
algos = [ConvAlgo.MaskSplitImplicitGemm] # algos = [ConvAlgo.Native]
for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid( for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes, devices, shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations, algos): strides, paddings, dilations, algos):
if all([s > 1, d > 1]): if all([s > 1, d > 1]):
continue # don't support this. continue # don't support this.
print(k, s, p, d) # print(dev, shape, bs, IC, OC, k, s, p, d)
device = torch.device(dev) device = torch.device(dev)
num_points = [1000] * bs num_points = [1000] * bs
dtype = torch.float32 dtype = torch.float32
...@@ -405,7 +405,7 @@ class TestSpConv(TestCase): ...@@ -405,7 +405,7 @@ class TestSpConv(TestCase):
features_dense_t = torch.from_numpy(features_dense).to(device).to( features_dense_t = torch.from_numpy(features_dense).to(device).to(
dtype) dtype)
features_dense_t.requires_grad = True features_dense_t.requires_grad = True
if net.algo == ConvAlgo.Native: if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO: if FILTER_HWIO:
filters = np.random.uniform(-1, 1, filters = np.random.uniform(-1, 1,
size=[k, k, k, IC, size=[k, k, k, IC,
...@@ -451,7 +451,7 @@ class TestSpConv(TestCase): ...@@ -451,7 +451,7 @@ class TestSpConv(TestCase):
for layer, layer_ref in zip(net.net, net_ref.net): for layer, layer_ref in zip(net.net, net_ref.net):
dw = layer.weight.grad.detach().cpu().numpy() dw = layer.weight.grad.detach().cpu().numpy()
dw_ref = layer_ref.weight.grad.detach().cpu().numpy() dw_ref = layer_ref.weight.grad.detach().cpu().numpy()
if net.algo == ConvAlgo.Native: if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO: if FILTER_HWIO:
dw = dw.transpose(4, 3, 0, 1, 2) dw = dw.transpose(4, 3, 0, 1, 2)
else: else:
...@@ -829,4 +829,4 @@ if __name__ == '__main__': ...@@ -829,4 +829,4 @@ if __name__ == '__main__':
# main(algo=spconv.ConvAlgo.SparseConvNet, dtype=torch.float32) # main(algo=spconv.ConvAlgo.SparseConvNet, dtype=torch.float32)
# TestCase().assertAllClose(out_my, out_ref) # TestCase().assertAllClose(out_my, out_ref)
# unittest.main() # unittest.main()
TestSpConv().testSpMaxPool3d() TestSpConv().testSpConv3d()
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare results between different algo:
CPU: gather-mm-scatter
Native: Fused gather-mm-scatter
ImplicitGemm
"""
...@@ -28,12 +28,12 @@ if (($CUDA_VERSION_FULL -eq "10.2") -or ($CUDA_VERSION_FULL -eq "11.0") -or ($CU ...@@ -28,12 +28,12 @@ if (($CUDA_VERSION_FULL -eq "10.2") -or ($CUDA_VERSION_FULL -eq "11.0") -or ($CU
) )
} elseif ($CUDA_VERSION_FULL -eq "11.3"){ } elseif ($CUDA_VERSION_FULL -eq "11.3"){
$CUDA_PACKAGES_IN = @( $CUDA_PACKAGES_IN = @(
"cuda_nvcc"; "nvcc";
"visual_studio_integration"; "visual_studio_integration";
"cuda_nvrtc"; "nvrtc_dev";
"cuda_cudart"; "cudart";
"cuda_thrust"; "thrust";
"libcurand"; "curand_dev";
) )
} else { } else {
# after cuda 11.4 # after cuda 11.4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment