Commit 4791f582 authored by yan.yan's avatar yan.yan
Browse files

working on spconv 2.2

parent ab5e2e8e
......@@ -24,4 +24,4 @@
* VoxelGenerator has been replaced by ```spconv.pytorch.utils.PointToVoxel``` (torch API) or Point2VoxelGPU[1-4]d/Point2VoxelCPU[1-4]d (tv.Tensor API).
* spconv < 2.1 doesn't support CPU. spconv 2.1+ supports CPU for debugging purposes.
* test spconv 1.x model in spconv 2.x: Firstly set environment variable before run program, Then set all ```algo``` in conv/pool to ```ConvAlgo.Native```. Linux: ```export SPCONV_FILTER_HWIO="1"```, Windows powershell: ```$Env:SPCONV_FILTER_HWIO = "1"```. **WARNING** test spconv 1.x model don't support implicit gemm algorithm.
* To test a spconv 1.x model in spconv 2.x, set this environment variable before running the program. Linux: ```export SPCONV_SAVED_WEIGHT_LAYOUT="RSCK"```, Windows powershell: ```$Env:SPCONV_SAVED_WEIGHT_LAYOUT = "RSCK"```.
\ No newline at end of file
......@@ -53,7 +53,7 @@ def reduce_mask_count_x(mask: np.ndarray, width: int):
return maskr
def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
def dev_subm_inds_v2(subm: bool = True, run_conv: bool = True):
limit_input_n = 16384
limit_input_n = None
np.random.seed(484)
......@@ -64,13 +64,13 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
voxels_np = voxels_np[:limit_input_n]
indices_np = indices_np[:limit_input_n]
spatial_shape = [19, 18, 17]
sparse_dict = generate_sparse_data(spatial_shape, [1024], 128)
# spatial_shape = [19, 18, 17]
# sparse_dict = generate_sparse_data(spatial_shape, [1024], 128)
voxels_np = np.ascontiguousarray(sparse_dict["features"]).astype(
np.float32)
indices_np = np.ascontiguousarray(
sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
# voxels_np = np.ascontiguousarray(sparse_dict["features"]).astype(
# np.float32)
# indices_np = np.ascontiguousarray(
# sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
voxels = tv.from_numpy(voxels_np).cuda()
indices = tv.from_numpy(indices_np).cuda()
......@@ -96,7 +96,7 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
dilation, out_padding, subm)
indice_num_per_loc_np = indice_num_per_loc.cpu().numpy()
indice_pairs_np = pair_ref.cpu().numpy()
algo = ConvAlgo.MaskSplitImplicitGemm
algo = ConvAlgo.MaskImplicitGemm
if algo == ConvAlgo.MaskImplicitGemm:
num_split = 1
else:
......@@ -116,7 +116,12 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
pair_bwd = res[3]
pair_mask_fwd_splits = res[4]
pair_mask_bwd_splits = res[5]
mask_tv = torch_tensor_to_tv(pair_mask_fwd_splits[0], dtype=tv.uint32).cpu().numpy()
bench_reduce_mask(mask_tv)
return
mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7]
masks = res[8]
......@@ -358,6 +363,47 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
np.linalg.norm(
dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1)))
def reverse_bits(a: np.ndarray):
    """Reverse the bit order inside every byte of ``a``.

    The bits are unpacked least-significant-bit first and packed again
    most-significant-bit first, so bit ``k`` of each byte ends up at bit
    ``7 - k``.

    Args:
        a: ``uint8`` array (``np.unpackbits`` accepts no other dtype).

    Returns:
        A flattened ``uint8`` array with one bit-reversed byte per input byte.
    """
    lsb_first_bits = np.unpackbits(a, bitorder="little")
    return np.packbits(lsb_first_bits, bitorder="big")
def _count_mask_reduce(masks: np.ndarray):
    """Print bit-count statistics for an array of masks.

    Prints three values: the sum of ``SpconvOps.count_bits`` over ``masks``,
    the result of ``reduce_mask_count(masks, 64)``, and the ratio of the
    second to the first.

    Args:
        masks: mask array handed to ``tv.from_numpy`` and
            ``reduce_mask_count``.
    """
    total_bits = SpconvOps.count_bits(tv.from_numpy(masks)).numpy_view().sum()
    reduced = reduce_mask_count(masks, 64)
    ratio = reduced / total_bits
    print(total_bits, reduced, ratio)
def bench_reduce_mask(masks: np.ndarray, width: int = 27):
    """Print ``_count_mask_reduce`` statistics for several mask orderings.

    Four variants are reported, in this order:

    1. ``masks`` sorted;
    2. ``masks`` restricted to the low ``width // 2 + 1`` bits, then sorted;
    3. ``masks`` concatenated with their bit-reversed copies
       (``SpconvOps.reverse_bits``), shuffled and sorted;
    4. variant 3 restricted to the low ``width // 2 + 1`` bits and re-sorted.

    Args:
        masks: mask array; it is not modified (all work is done on copies).
        width: number of mask bits used to derive the half-width bit mask.
    """
    all_ones = np.array(0xffffffff, dtype=np.uint32)
    half_bits = width // 2 + 1
    # width_mask and width_half_mask_left are computed but never read below.
    width_mask = all_ones << (32 - width) >> (32 - width)
    width_half_mask = all_ones >> (32 - half_bits)
    width_half_mask_left = width_half_mask << half_bits
    print(bin(width_half_mask))

    # Variants 1 and 2: the original masks only.
    _count_mask_reduce(np.sort(masks))
    _count_mask_reduce(np.sort(masks & width_half_mask))

    # Variants 3 and 4: masks plus their bit-reversed counterparts.
    mirrored = SpconvOps.reverse_bits(tv.from_numpy(masks)).numpy()
    combined = np.concatenate([masks, mirrored])
    np.random.shuffle(combined)
    combined.sort()
    _count_mask_reduce(combined)
    combined &= width_half_mask
    combined.sort()
    _count_mask_reduce(combined)
if __name__ == "__main__":
dev_subm_inds_v2()
......@@ -17,3 +17,5 @@ from . import build as _build
from .core import ConvAlgo, AlgoHint
from . import constants
from .__version__ import __version__
SPCONV_VERSION_NUMBERS = list(map(int, __version__.split(".")))
\ No newline at end of file
......@@ -320,15 +320,16 @@ class SimpleGemm:
c_inds.shape)
avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c,
arch, shuffle_type)
c_ = c.clone()
# c may be weight, may non-contiguous.
# cumm.tensorview.Tensor don't support non-contiguous clone
c_ = c.clone_whole_storage()
times: List[float] = []
best_gather_params = (-1, -1, -1, -1)
best_scatter_params = (-1, -1, -1, -1)
all_profile_res: List[BestAlgoByProfile] = []
for desp in avail:
c_.zero_()
c_.zero_whole_storage_()
split_k_slices = 1
# TODO better splitk selection
if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
......
......@@ -23,7 +23,17 @@ PACKAGE_ROOT = Path(__file__).parent.resolve()
EDITABLE_INSTALLED = project_is_installed(
PACKAGE_NAME) and project_is_editable(PACKAGE_NAME)
_filter_hwio_env = os.getenv("SPCONV_FILTER_HWIO", "0")
FILTER_HWIO = _filter_hwio_env == "1"
_filter_hwio_env = os.getenv("SPCONV_FILTER_HWIO", None)
if _filter_hwio_env is not None:
raise NotImplementedError("SPCONV_FILTER_HWIO is deprecated. use SPCONV_SAVED_WEIGHT_LAYOUT instead.")
DISABLE_JIT = os.getenv("SPCONV_DISABLE_JIT", "0") == "1"
NDIM_DONT_CARE = 3
FILTER_HWIO = False
SAVED_WEIGHT_LAYOUT = os.getenv("SPCONV_SAVED_WEIGHT_LAYOUT", "")
if SAVED_WEIGHT_LAYOUT != "":
assert SAVED_WEIGHT_LAYOUT in ["KRSC", "RSKC", "RSCK"], "please set SAVED_WEIGHT_LAYOUT to KRSC, RSKC or RSCK"
ALL_WEIGHT_IS_KRSC = True
\ No newline at end of file
......@@ -83,6 +83,9 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((16, 32, 8), (16, 16, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# fall back kernels if mat is misaligned for half
# TODO use access-per-vector kernel instead of simt kernel for fallback
*gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f16,f16"],
......
......@@ -283,6 +283,13 @@ class SpconvOps:
"""
...
@staticmethod
def reverse_bits(a: Tensor) -> Tensor:
    """
    Return a new tensor in which the bits of every element of ``a`` are
    reversed.

    Args:
        a: input tensor; the implementation dispatches on uint32 and uint64
            element types only.

    Returns:
        A tensor created with the same shape, dtype and device as ``a``.
    """
    ...
@staticmethod
def calc_point2voxel_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]:
"""
Args:
......
......@@ -815,6 +815,66 @@ class SpconvOps(pccm.Class):
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.cuda.static_function
def reverse_bits(self):
    """Generate ``reverse_bits(a)``: reverse the bits of each tensor element.

    The generated function accepts a uint32 or uint64 ``tv::Tensor`` and
    returns a new tensor with the same shape, dtype and device. CPU tensors
    (``device() == -1``) are handled with the host ``reverse`` helpers; any
    other device launches a kernel built on ``__brev`` / ``__brevll``.

    The generated function is marked invalid in CPU-only builds.
    """
    code = pccm.FunctionCode()
    if CUMM_CPU_ONLY_BUILD:
        return code.make_invalid()
    code.add_dependency(TensorViewKernel)
    code.arg("a", "tv::Tensor")
    code.code_after_include = f"""
    __global__ void reverse_bits_kernel_64(const uint64_t* data, uint64_t* out, int size){{
        for (int i : tv::KernelLoopX<int>(size)){{
            out[i] = __brevll(reinterpret_cast<const unsigned long long*>(data)[i]);
        }}
    }}
    __global__ void reverse_bits_kernel(const uint32_t* data, uint32_t* out, int size){{
        for (int i : tv::KernelLoopX<int>(size)){{
            out[i] = __brev(data[i]);
        }}
    }}
    uint32_t reverse(uint32_t x)
    {{
        x = ((x >> 1) & 0x55555555u) | ((x & 0x55555555u) << 1);
        x = ((x >> 2) & 0x33333333u) | ((x & 0x33333333u) << 2);
        x = ((x >> 4) & 0x0f0f0f0fu) | ((x & 0x0f0f0f0fu) << 4);
        x = ((x >> 8) & 0x00ff00ffu) | ((x & 0x00ff00ffu) << 8);
        x = ((x >> 16) & 0xffffu) | ((x & 0xffffu) << 16);
        return x;
    }}
    // The 64-bit variant must return uint64_t and widen the low half before
    // shifting: shifting a 32-bit value by 32 is undefined behaviour, and an
    // int return type would truncate the result.
    uint64_t reverse(uint64_t i)
    {{
        return (uint64_t(reverse(uint32_t(i))) << 32) | uint64_t(reverse(uint32_t(i >> 32)));
    }}
    """
    code.raw(f"""
    tv::Tensor res(a.shape(), a.dtype(), a.device());
    tv::dispatch<uint32_t, uint64_t>(a.dtype(), [&](auto I){{
        using T = TV_DECLTYPE(I);
        auto res_ptr = res.data_ptr<T>();
        auto a_ptr = a.data_ptr<const T>();
        if (a.device() == -1){{
            for (int i = 0; i < a.size(); ++i){{
                res_ptr[i] = reverse(a_ptr[i]);
            }}
        }}else{{
            tv::cuda::Launch launcher(a.size());
            tv::if_constexpr<std::is_same<T, uint64_t>::value>([=](auto _)mutable{{
                launcher(_(reverse_bits_kernel_64), a_ptr, res_ptr, int(a.size()));
            }}, [=](auto _)mutable{{
                launcher(_(reverse_bits_kernel), a_ptr, res_ptr, int(a.size()));
            }});
        }}
    }});
    return res;
    """)
    return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.static_function
def calc_point2voxel_meta_data(self):
......
......@@ -16,7 +16,6 @@ import pccm
from ccimport import compat
from cumm.common import TensorView
class OMPLib(pccm.Class):
def __init__(self):
super().__init__()
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import pccm
from cumm.common import TensorView
from cumm.common import TensorView, GemmDTypes
from cumm.constants import CUMM_CPU_ONLY_BUILD
from spconv.csrc.sparse.cpu_core import OMPLib
from typing import List
......@@ -24,7 +24,7 @@ class GatherCPU(pccm.Class):
super().__init__()
if CUMM_CPU_ONLY_BUILD:
self.add_dependency(OMPLib)
self.add_dependency(TensorView)
self.add_dependency(TensorView, GemmDTypes)
self.add_include("tensorview/parallel/all.h")
@pccm.static_function
......@@ -39,7 +39,7 @@ class GatherCPU(pccm.Class):
auto nhot = inds.dim(0);
int channel = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::bfloat16_t, tv::half_t>(out.dtype(), [&](auto I){{
auto indices_data = inds.data_ptr<const int>();
using T = TV_DECLTYPE(I);
T *buffer_data = out.data_ptr<T>();
......@@ -65,7 +65,7 @@ class GatherCPU(pccm.Class):
// tv::check_shape(inds, {{in.dim(0)}});
auto nhot = inds.dim(0);
int channel = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::bfloat16_t, tv::half_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto indices_data = inds.data_ptr<const int>();
const T *buffer_data = in.data_ptr<const T>();
......
......@@ -18,7 +18,7 @@ from cumm.gemm.core.metaarray import MetaArray, seq
from cumm import dtypes
import pccm
from cumm.gemm.layout import TensorGeneric, to_stride
from cumm.common import TensorView, TensorViewHashKernel, TensorViewKernel, ThrustLib, GemmBasic
from cumm.common import TensorView, GemmDTypes, TensorViewKernel, ThrustLib, GemmBasic
from cumm.gemm import codeops
from typing import List
from cumm.conv.params import ConvProblem
......@@ -353,7 +353,7 @@ class IndiceMaxPool(pccm.Class):
class IndiceMaxPoolCPU(pccm.Class):
def __init__(self):
super().__init__()
self.add_dependency(TensorView)
self.add_dependency(TensorView, GemmDTypes)
if CUMM_CPU_ONLY_BUILD:
self.add_dependency(OMPLib)
self.add_include("tensorview/parallel/all.h")
......@@ -370,7 +370,7 @@ class IndiceMaxPoolCPU(pccm.Class):
code.raw(f"""
int nhot = out_inds.dim(0);
int num_features = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto out_features = out.data_ptr<T>();
auto in_features = in.data_ptr<const T>();
......@@ -410,7 +410,7 @@ class IndiceMaxPoolCPU(pccm.Class):
code.raw(f"""
int nhot = out_inds.dim(0);
int num_features = in.dim(1);
tv::dispatch<float, double>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto out_features = out.data_ptr<const T>();
auto in_features = in.data_ptr<const T>();
......
......@@ -23,15 +23,17 @@ from torch.nn import init
from torch.nn.parameter import Parameter
from spconv import pytorch as spconv
from spconv import SPCONV_VERSION_NUMBERS
from spconv.core import ConvAlgo
from spconv.pytorch import functional as Fsp
from spconv.pytorch import ops
from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData
from spconv.pytorch.modules import SparseModule
from spconv.constants import FILTER_HWIO
from spconv.constants import SAVED_WEIGHT_LAYOUT, ALL_WEIGHT_IS_KRSC
from spconv.utils import nullcontext
FILTER_HWIO = False
def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo):
dimensions = tensor.ndimension()
......@@ -132,7 +134,7 @@ class SparseConvolution(SparseModule):
assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
self.algo = algo
# self.algo = ConvAlgo.Native
if self.algo == ConvAlgo.Native:
if self.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO:
# RSCK
self.weight = Parameter(
......@@ -152,6 +154,37 @@ class SparseConvolution(SparseModule):
self.register_parameter('bias', None)
self.reset_parameters()
self._register_load_state_dict_pre_hook(self._load_weight_different_layout)
def _load_weight_different_layout(
        self, state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs):
    """``load_state_dict`` pre-hook that converts a saved weight layout.

    When ``SAVED_WEIGHT_LAYOUT`` is set, the tensor stored under
    ``prefix + "weight"`` in ``state_dict`` is replaced by a permuted,
    contiguous copy matching the layout this module uses. The weight is
    permuted exactly once; permuting it a second time would scramble the
    axes.

    Args:
        state_dict: state dict being loaded; modified in place.
        prefix: key prefix of this module inside ``state_dict``.
        local_metadata, strict, missing_keys, unexpected_keys, error_msgs:
            part of the pre-hook signature, unused here.

    Raises:
        AssertionError: if ``prefix + "weight"`` is missing from
            ``state_dict`` while ``SAVED_WEIGHT_LAYOUT`` is set.
    """
    if not SAVED_WEIGHT_LAYOUT:
        return
    key = prefix + "weight"
    assert key in state_dict
    ndim = self.ndim
    saved_weight = state_dict[key]
    if ALL_WEIGHT_IS_KRSC or self.algo != ConvAlgo.Native:
        # in spconv 2.2, we only support KRSC layout.
        if SAVED_WEIGHT_LAYOUT == "RSKC":
            # move K (axis ndim) to the front: RSKC -> KRSC
            state_dict[key] = saved_weight.permute(
                ndim, *range(ndim), ndim + 1).contiguous()
        elif SAVED_WEIGHT_LAYOUT == "RSCK":
            # move K (last axis) to the front: RSCK -> KRSC
            state_dict[key] = saved_weight.permute(
                ndim + 1, *range(ndim), ndim).contiguous()
    else:
        # Reached only for the Native algorithm without KRSC weights.
        # NOTE(review): the original comment here says "to RSCK", but the
        # KRSC case below produces RSKC while the RSKC case produces RSCK.
        # Confirm the intended target layout before enabling this path.
        if SAVED_WEIGHT_LAYOUT == "RSKC":
            state_dict[key] = saved_weight.permute(
                *range(ndim), ndim + 1, ndim).contiguous()
        elif SAVED_WEIGHT_LAYOUT == "KRSC":
            state_dict[key] = saved_weight.permute(
                *range(1, ndim + 1), 0, ndim + 1).contiguous()
def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
......
......@@ -31,7 +31,7 @@ _TORCH_DTYPE_TO_TV = {
def torch_tensor_to_tv(ten: torch.Tensor,
dtype: Optional[int] = None,
shape: Optional[List[int]] = None):
assert ten.is_contiguous(), "must be contiguous tensor"
# assert ten.is_contiguous(), "must be contiguous tensor"
ptr = ten.data_ptr()
device = ten.device
if device.type == "cpu":
......@@ -44,7 +44,7 @@ def torch_tensor_to_tv(ten: torch.Tensor,
shape = list(ten.shape)
if dtype is None:
dtype = _TORCH_DTYPE_TO_TV[ten.dtype]
return tv.from_blob(ptr, shape, dtype, tv_device)
return tv.from_blob(ptr, shape, list(ten.stride()), dtype, tv_device)
def get_current_stream():
......
......@@ -36,7 +36,7 @@ else:
GEMM = None
CONV = None
import time
from spconv.constants import FILTER_HWIO
from spconv.constants import FILTER_HWIO, ALL_WEIGHT_IS_KRSC
from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer
......@@ -606,21 +606,40 @@ def indice_conv(features: torch.Tensor,
if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress")
if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
if FILTER_HWIO:
out_channel = filters.shape[-1]
filter_shape_per_kv = [filters.shape[-2], out_channel]
else:
out_channel = filters.shape[-2]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0]
else:
kv_dim = 1
out_channel = filters.shape[0]
filters = filters.reshape(out_channel, -1, filters.shape[-1])
is_KC_not_CK = True
kv = filters.shape[1]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2
if subm:
# out_features = torch.zeros((num_activate_out, out_channel),
# dtype=features.dtype,
# device=features.device)
if FILTER_HWIO:
if not ALL_WEIGHT_IS_KRSC:
if not is_KC_not_CK:
out_features = torch.mm(features, filters[kv_center])
else:
out_features = torch.mm(features, filters[kv_center].T)
else:
out_features = torch.mm(features, filters[:, kv_center].T)
else:
out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype,
......@@ -640,7 +659,6 @@ def indice_conv(features: torch.Tensor,
pair_in = indice_pairs_tv[int(inverse)]
pair_out = indice_pairs_tv[int(not inverse)]
filters_tv = torch_tensor_to_tv(filters)
if not features.is_cuda:
# perform gather-mm-scatter_add for cpu data
assert not filters.is_cuda
......@@ -662,7 +680,8 @@ def indice_conv(features: torch.Tensor,
inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot)
SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices)
filters_cur = filters[i] if FILTER_HWIO else filters[i].T
filters_i = filters.select(kv_dim, i)
filters_cur = filters_i if not is_KC_not_CK else filters_i.T
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices)
......@@ -689,10 +708,10 @@ def indice_conv(features: torch.Tensor,
filters_tv.dtype,
c.dtype,
a.shape,
filters.shape[-2:],
filter_shape_per_kv,
c.shape,
False,
False if FILTER_HWIO else True,
is_KC_not_CK,
False,
arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC,
......@@ -708,13 +727,14 @@ def indice_conv(features: torch.Tensor,
inp_indices = torch_tensor_to_tv(inp_indices_th)
out_indices = torch_tensor_to_tv(out_indices_th)
filter_tv = torch_tensor_to_tv(filters)[profile_idx]
filter_tv = torch_tensor_to_tv(filters).select(kv_dim, profile_idx)
tuned_res, min_time = GEMM.tune_and_cache(
a,
filter_tv,
c,
False,
False if FILTER_HWIO else True,
is_KC_not_CK,
False,
arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC,
......@@ -736,7 +756,7 @@ def indice_conv(features: torch.Tensor,
continue
inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot)
b = filters_tv[i]
b = filters_tv.select(kv_dim, i)
# inp @ filter.T, NC @ KC
beta = 1.0 if inited else 0.0
algo_desp = GEMM.run_with_tuned_result(
......@@ -745,7 +765,7 @@ def indice_conv(features: torch.Tensor,
b,
c,
False,
False if FILTER_HWIO else True,
is_KC_not_CK,
False,
arch=arch,
stream=stream,
......@@ -783,11 +803,27 @@ def indice_conv_backward(features: torch.Tensor,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# print(out_bp.mean(), out_bp.max(), out_bp.min())
num_activate_out = out_bp.shape[0]
out_channel = out_bp.shape[-1]
filters_shape = filters.shape
if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
if FILTER_HWIO:
out_channel = filters.shape[-1]
filter_shape_per_kv = [filters.shape[-2], out_channel]
else:
out_channel = filters.shape[-2]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0]
else:
kv_dim = 1
out_channel = filters.shape[0]
filters = filters.reshape(out_channel, -1, filters.shape[-1])
is_KC_not_CK = True
kv = filters.shape[1]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
......@@ -797,20 +833,24 @@ def indice_conv_backward(features: torch.Tensor,
if subm:
dfilters = torch.zeros_like(filters)
if FILTER_HWIO:
if not ALL_WEIGHT_IS_KRSC:
if not is_KC_not_CK:
torch.mm(features.T, out_bp, out=dfilters[kv_center])
# TODO can we use torch mm for f16 backward weight?
din = torch.mm(out_bp, filters[kv_center].T)
else:
torch.mm(out_bp.T, features, out=dfilters[kv_center])
# TODO can we use torch mm for f16 backward weight?
din = torch.mm(out_bp, filters[kv_center])
else:
# KN @ NC
torch.mm(out_bp.T, features, out=dfilters[:, kv_center])
# NK @ KC
din = torch.mm(out_bp, filters[:, kv_center])
else:
dfilters = torch.zeros_like(filters)
din = torch.zeros_like(features)
if kv == 1 and subm:
return (din, dfilters.reshape(filters_shape))
inited: bool = subm
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
# torch slice (a_th[x]) is very slow, so we need to use tv.Tensor earlier.
......@@ -854,12 +894,18 @@ def indice_conv_backward(features: torch.Tensor,
out_indices = pair_out[i].slice_first_axis(0, nhot)
SpconvOps.gather_cpu(inp_buffer_tv, features_tv, inp_indices)
SpconvOps.gather_cpu(out_buffer_tv, out_bp_tv, out_indices)
filters_T_cur = filters[i].T if FILTER_HWIO else filters[i]
dfilters_cur = dfilters[i] if FILTER_HWIO else dfilters[i].T
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_cur)
torch.mm(out_buffer[:nhot], filters_T_cur, out=inp_buffer[:nhot])
filters_i = filters.select(kv_dim, i)
dfilters_i = dfilters.select(kv_dim, i)
filters_KC = filters_i if is_KC_not_CK else filters_i.T
if is_KC_not_CK:
# KN @ NC
torch.mm(out_buffer[:nhot].T, inp_buffer[:nhot], out=dfilters_i)
else:
# CN @ NK
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_i)
# NK @ KC
torch.mm(out_buffer[:nhot], filters_KC, out=inp_buffer[:nhot])
SpconvOps.scatter_add_cpu(din_tv, inp_buffer_tv, inp_indices)
return (din, dfilters.reshape(filters_shape))
......@@ -883,10 +929,10 @@ def indice_conv_backward(features: torch.Tensor,
filters_tv.dtype,
din_tv.dtype,
out_bp_tv.shape,
filters.shape[-2:],
filter_shape_per_kv,
din_tv.shape,
False,
True if FILTER_HWIO else False,
not is_KC_not_CK,
False,
arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC,
......@@ -896,13 +942,13 @@ def indice_conv_backward(features: torch.Tensor,
if tuned_res_dgrad is None:
inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile)
out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile)
filter_tv = filters_tv[profile_idx]
filter_tv = filters_tv.select(kv_dim, profile_idx)
tuned_res_dgrad, min_time = GEMM.tune_and_cache(
out_bp_tv,
filter_tv,
din_tv,
False,
True if FILTER_HWIO else False,
not is_KC_not_CK,
False,
arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC,
......@@ -912,7 +958,7 @@ def indice_conv_backward(features: torch.Tensor,
beta=0.0,
hint=AlgoHint.BackwardInput.value,
stream=stream)
if not FILTER_HWIO:
if is_KC_not_CK:
a_wgrad = out_bp_tv
b_wgrad = features_tv
else:
......@@ -924,7 +970,7 @@ def indice_conv_backward(features: torch.Tensor,
filters_tv.dtype,
a_wgrad.shape,
b_wgrad.shape,
filters.shape[-2:],
filter_shape_per_kv,
True,
False,
False,
......@@ -937,8 +983,8 @@ def indice_conv_backward(features: torch.Tensor,
if tuned_res_wgrad is None:
inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile)
out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile)
dfilter_tv = dfilters_tv[profile_idx]
if not FILTER_HWIO:
dfilter_tv = dfilters_tv.select(kv_dim, profile_idx)
if is_KC_not_CK:
a_inds_wgrad = out_indices
b_inds_wgrad = inp_indices
else:
......@@ -961,7 +1007,7 @@ def indice_conv_backward(features: torch.Tensor,
stream=stream)
# print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time)
# get workspace size for wgrad
if not FILTER_HWIO:
if is_KC_not_CK:
a_shape = [maxnhot, out_bp_tv.dim(1)]
b_shape = [maxnhot, features_tv.dim(1)]
else:
......@@ -1003,13 +1049,13 @@ def indice_conv_backward(features: torch.Tensor,
inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot)
# out.T @ inp, NK @ NC
# print(features_tv.shape, out_bp_tv.shape)
filter_i_tv = filters_tv.select(kv_dim, i)
GEMM.run_with_tuned_result(tuned_res_dgrad,
out_bp_tv,
filters_tv[i],
filter_i_tv,
din_tv,
False,
True if FILTER_HWIO else False,
not is_KC_not_CK,
False,
arch=arch,
stream=stream,
......@@ -1033,7 +1079,7 @@ def indice_conv_backward(features: torch.Tensor,
GEMM.run_with_tuned_result(tuned_res_wgrad,
a,
b,
dfilters_tv[i],
dfilters_tv.select(kv_dim, i),
True,
False,
False,
......
......@@ -168,8 +168,8 @@ class Net(nn.Module):
# nn.ReLU(),
# spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo),
# # nn.BatchNorm1d(128),
# # nn.ReLU(),
# # # nn.BatchNorm1d(128),
# # # nn.ReLU(),
# spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo),
)
......@@ -312,7 +312,8 @@ def main():
# MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms
# algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype).train()
net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train()
# net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape)
out = net(voxels_th, coors_th, 1)
......@@ -323,25 +324,25 @@ def main():
print(out.spatial_shape, out.features.mean(), out.features.max(),
out.features.min())
times = []
with torch.no_grad():
for i in range(20):
print("------------")
torch.cuda.synchronize()
t = time.time()
out_nograd = net(voxels_th, coors_th, 1, True)
timer = out_nograd._timer
res = timer.collect_by_name("forward", timer.get_all_pair_time())
res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
# times = []
# with torch.no_grad():
# for i in range(20):
# print("------------")
# torch.cuda.synchronize()
# t = time.time()
# out_nograd = net(voxels_th, coors_th, 1, False)
# timer = out_nograd._timer
# # res = timer.collect_by_name("forward", timer.get_all_pair_time())
# # res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
print(sum(res.values()) + sum(res2.values()))
# print(timer.get_all_pair_time())
# # print(sum(res.values()) + sum(res2.values()))
# # print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values()))
torch.cuda.synchronize()
# sort_bench()
times.append(time.time() - t)
print("spconv time", np.mean(times[10:]))
# # print(sum(timer.get_all_pair_time().values()))
# torch.cuda.synchronize()
# # sort_bench()
# times.append(time.time() - t)
# print("spconv time", np.mean(times[10:]))
# times = []
# for i in range(10):
......
......@@ -23,7 +23,7 @@ from spconv.core import ConvAlgo
import spconv.pytorch as spconv
from spconv.test_utils import TestCase, generate_sparse_data, params_grid
from spconv.constants import FILTER_HWIO
from spconv.constants import ALL_WEIGHT_IS_KRSC, FILTER_HWIO
# import sparseconvnet as scn
# we must disable tf32 to increase reference precision.
......@@ -368,14 +368,14 @@ class TestSpConv(TestCase):
ConvAlgo.Native, ConvAlgo.MaskImplicitGemm,
ConvAlgo.MaskSplitImplicitGemm
]
algos = [ConvAlgo.MaskSplitImplicitGemm]
# algos = [ConvAlgo.Native]
for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
devices, shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations, algos):
if all([s > 1, d > 1]):
continue # don't support this.
print(k, s, p, d)
# print(dev, shape, bs, IC, OC, k, s, p, d)
device = torch.device(dev)
num_points = [1000] * bs
dtype = torch.float32
......@@ -405,7 +405,7 @@ class TestSpConv(TestCase):
features_dense_t = torch.from_numpy(features_dense).to(device).to(
dtype)
features_dense_t.requires_grad = True
if net.algo == ConvAlgo.Native:
if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO:
filters = np.random.uniform(-1, 1,
size=[k, k, k, IC,
......@@ -451,7 +451,7 @@ class TestSpConv(TestCase):
for layer, layer_ref in zip(net.net, net_ref.net):
dw = layer.weight.grad.detach().cpu().numpy()
dw_ref = layer_ref.weight.grad.detach().cpu().numpy()
if net.algo == ConvAlgo.Native:
if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO:
dw = dw.transpose(4, 3, 0, 1, 2)
else:
......@@ -829,4 +829,4 @@ if __name__ == '__main__':
# main(algo=spconv.ConvAlgo.SparseConvNet, dtype=torch.float32)
# TestCase().assertAllClose(out_my, out_ref)
# unittest.main()
TestSpConv().testSpMaxPool3d()
TestSpConv().testSpConv3d()
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare results between different algo:
CPU: gather-mm-scatter
Native: Fused gather-mm-scatter
ImplicitGemm
"""
......@@ -28,12 +28,12 @@ if (($CUDA_VERSION_FULL -eq "10.2") -or ($CUDA_VERSION_FULL -eq "11.0") -or ($CU
)
} elseif ($CUDA_VERSION_FULL -eq "11.3"){
$CUDA_PACKAGES_IN = @(
"cuda_nvcc";
"nvcc";
"visual_studio_integration";
"cuda_nvrtc";
"cuda_cudart";
"cuda_thrust";
"libcurand";
"nvrtc_dev";
"cudart";
"thrust";
"curand_dev";
)
} else {
# after cuda 11.4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment